mirror of https://github.com/HazyResearch/manifest
feat: add local huggingface embedding models (#76)
parent 40de0e7f59
commit 0fb192a0a2
@@ -0,0 +1,113 @@
"""Sentence transformer model."""
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from manifest.api.models.model import Model


class SentenceTransformerModel(Model):
    """SentenceTransformer model."""

    def __init__(
        self,
        model_name_or_path: str,
        model_type: Optional[str] = None,
        model_config: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: int = 0,
        use_accelerate: bool = False,
        use_parallelize: bool = False,
        use_bitsandbytes: bool = False,
        use_deepspeed: bool = False,
        perc_max_gpu_mem_red: float = 1.0,
        use_fp16: bool = False,
    ):
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.

        Args:
            model_name_or_path: model name string.
            model_type: model type string.
            model_config: model config string.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize.
            use_bitsandbytes: use HF bits and bytes.
            use_deepspeed: use deepspeed.
            perc_max_gpu_mem_red: percent max memory reduction in accelerate.
            use_fp16: use fp16 for model weights.
        """
        if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
            raise ValueError(
                "Cannot use accelerate or parallelize or "
                "bitsandbytes or deepspeed with sentence transformers"
            )
        # Check if providing path
        self.model_name = model_name_or_path
        print("Model Name:", self.model_name)
        torch_device = (
            torch.device("cpu")
            if (device == -1 or not torch.cuda.is_available())
            else torch.device(f"cuda:{device}")
        )
        self.embedding_model = SentenceTransformer(
            self.model_name, device=torch_device
        )
        self.embedding_model.to(torch_device)
        self.embedding_model.eval()

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {"model_name": self.model_name, "model_path": self.model_name}

    @torch.no_grad()
    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float, List[int], List[float]]]:
        """
        Generate the prompt from model.

        Outputs must be generated text and score, not including prompt.

        Args:
            prompt: prompt to generate from.

        Returns:
            list of generated text (list of length 1 for 1 generation).
        """
        raise NotImplementedError("Generate not supported for sentence transformers")

    @torch.no_grad()
    def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
        """
        Embed the prompt from model.

        Args:
            prompt: prompt to embed from.

        Returns:
            list of embeddings (list of length 1 for 1 embedding).
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        return self.embedding_model.encode(prompt)

    @torch.no_grad()
    def score_sequence(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[float, List[int], List[float]]]:
        """
        Score a sequence of choices.

        Args:
            prompt (:obj:`str` or :obj:`List[str]`):
                The prompt to score the choices against.
            **kwargs:
                Additional keyword arguments passed along to the :obj:`__call__` method.
        """
        raise NotImplementedError(
            "Score sequence not supported for sentence transformers"
        )
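For orientation, a minimal usage sketch of the new model class follows. The module path (manifest.api.models.sentence_transformer) and the model name (sentence-transformers/all-MiniLM-L6-v2) are illustrative assumptions, not something fixed by this commit:

# Minimal usage sketch (illustrative); module path and model name are assumptions.
from manifest.api.models.sentence_transformer import SentenceTransformerModel

model = SentenceTransformerModel(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    device=-1,  # -1 (or no CUDA available) falls back to CPU in the constructor above
)
vectors = model.embed(["hello world", "store this sentence"])
print(vectors.shape)  # (2, embedding_dim) numpy array returned by encode()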
@@ -0,0 +1,89 @@
"""Hugging Face client."""
import logging
from typing import Any, Dict, Optional, Tuple

import numpy as np
import requests

from manifest.clients.client import Client
from manifest.request import EmbeddingRequest

logger = logging.getLogger(__name__)


class HuggingFaceEmbeddingClient(Client):
    """HuggingFaceEmbedding client."""

    # User param -> (client param, default value)
    PARAMS: Dict[str, Tuple[str, Any]] = {}
    REQUEST_CLS = EmbeddingRequest
    NAME = "huggingfaceembedding"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the HuggingFace url.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        if not connection_str:
            raise ValueError("Must provide connection string")
        self.host = connection_str.rstrip("/")
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))

    def close(self) -> None:
        """Close the client."""
        pass

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/embed"

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header.
        """
        return {}

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        res = requests.post(self.host + "/params").json()
        res["client_name"] = self.NAME
        return res

    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response
            request: request

        Returns:
            response as dict
        """
        # Convert array to np.array
        for choice in response["choices"]:
            choice["array"] = np.array(choice["array"])
        return response
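And a minimal sketch of the client side, assuming the base Client constructor forwards connection_str to connect() (as other Manifest clients do), that the module lives at manifest.clients.huggingface_embedding, and that an embedding server is already listening at http://localhost:5000; all three are illustrative assumptions:

# Illustrative sketch; module path, server URL, and constructor behavior are assumptions.
import numpy as np
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient

client = HuggingFaceEmbeddingClient(connection_str="http://localhost:5000")
print(client.get_generation_url())  # http://localhost:5000/embed

# format_response converts each choice's "array" list into an np.ndarray
raw = {"choices": [{"array": [0.1, 0.2, 0.3]}]}
formatted = client.format_response(raw, request={})
assert isinstance(formatted["choices"][0]["array"], np.ndarray)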