From 0fb192a0a205599df02a66208a918c577e0a8e6b Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Sat, 8 Apr 2023 23:55:50 -0700 Subject: [PATCH] feat: add local huggingface embedding models (#76) --- examples/manifest_embedding.ipynb | 31 ++++-- manifest/api/app.py | 44 ++++++-- manifest/api/models/diffuser.py | 16 ++- manifest/api/models/huggingface.py | 25 ++++- manifest/api/models/model.py | 11 +- manifest/api/models/sentence_transformer.py | 113 ++++++++++++++++++++ manifest/api/response.py | 11 +- manifest/caches/cache.py | 1 + manifest/clients/diffuser.py | 5 +- manifest/clients/huggingface.py | 5 +- manifest/clients/huggingface_embedding.py | 89 +++++++++++++++ manifest/manifest.py | 2 + manifest/request.py | 5 +- manifest/response.py | 4 + pyproject.toml | 1 + setup.py | 5 +- tests/test_huggingface_api.py | 33 ++++++ tests/test_manifest.py | 95 ++++++++++++++++ 18 files changed, 455 insertions(+), 41 deletions(-) create mode 100644 manifest/api/models/sentence_transformer.py create mode 100644 manifest/clients/huggingface_embedding.py diff --git a/examples/manifest_embedding.ipynb b/examples/manifest_embedding.ipynb index 862a3fe..f60c59d 100644 --- a/examples/manifest_embedding.ipynb +++ b/examples/manifest_embedding.ipynb @@ -69,19 +69,24 @@ "```\n", "python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n", "```\n", - "in a separate `screen` or `tmux`." + "or\n", + "```\n", + "python3 manifest/api/app.py --model_type sentence_transformers --model_name_or_path all-mpnet-base-v2 --device 0\n", + "```\n", + "\n", + "in a separate `screen` or `tmux`. Make sure to note the port. You can change this with `export FLASK_PORT=`." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'model_name': 'EleutherAI/gpt-neo-125M', 'model_path': 'EleutherAI/gpt-neo-125M'}\n" + "{'model_name': 'all-mpnet-base-v2', 'model_path': 'all-mpnet-base-v2', 'client_name': 'huggingfaceembedding'}\n" ] } ], @@ -91,7 +96,7 @@ "# Local hosted GPT Neo 125M\n", "manifest = Manifest(\n", " client_name=\"huggingfaceembedding\",\n", - " client_connection=\"http://127.0.0.1:6001\",\n", + " client_connection=\"http://127.0.0.1:6000\",\n", " cache_name=\"sqlite\",\n", " cache_connection=\"my_sqlite_manifest.sqlite\"\n", ")\n", @@ -100,12 +105,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(768,)\n", + "(768,) (768,)\n" + ] + } + ], "source": [ "emb = manifest.run(\"Is this an embedding?\")\n", - "emb2 = manifest.run(\"Is this an embedding?\", aggregation=\"mean\")" + "print(emb.shape)\n", + "\n", + "emb = manifest.run([\"Is this an embedding?\", \"Bananas!!!\"])\n", + "print(emb[0].shape, emb[1].shape)" ] } ], diff --git a/manifest/api/app.py b/manifest/api/app.py index 1856844..a8f9ebc 100644 --- a/manifest/api/app.py +++ b/manifest/api/app.py @@ -16,6 +16,7 @@ from manifest.api.models.huggingface import ( CrossModalEncoderModel, TextGenerationModel, ) +from manifest.api.models.sentence_transformer import SentenceTransformerModel from manifest.api.response import ModelResponse os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -28,6 +29,7 @@ model_type = None PORT = int(os.environ.get("FLASK_PORT", 5000)) MODEL_CONSTRUCTORS = { "huggingface": TextGenerationModel, + 
"sentence_transformers": SentenceTransformerModel, "huggingface_crossmodal": CrossModalEncoderModel, "diffuser": DiffuserModel, } @@ -198,11 +200,14 @@ def completions() -> Response: @app.route("/embed", methods=["POST"]) -def embed() -> Dict: +def embed() -> Response: """Get embed for generation.""" - modality = request.json["modality"] + if "modality" in request.json: + modality = request.json["modality"] + else: + modality = "text" if modality == "text": - prompts = request.json["prompts"] + prompts = request.json["prompt"] elif modality == "image": import base64 @@ -210,19 +215,36 @@ def embed() -> Dict: prompts = [ Image.open(io.BytesIO(base64.b64decode(data))) - for data in request.json["prompts"] + for data in request.json["prompt"] ] else: raise ValueError("modality must be text or image") - results = [] - embeddings = model.embed(prompts) - for embedding in embeddings: - results.append(embedding.tolist()) + try: + results = [] + embeddings = model.embed(prompts) + for embedding in embeddings: + results.append( + { + "array": embedding, + "logprob": None, + "tokens": None, + "token_logprobs": None, + } + ) - # transform the result into the openai format - # return Response(results, response_type="text_completion").__dict__() - return {"result": results} + return Response( + json.dumps( + ModelResponse(results, response_type="embedding_generation").__dict__() + ), + status=200, + ) + except Exception as e: + logger.error(e) + return Response( + json.dumps({"message": str(e)}), + status=400, + ) @app.route("/score_sequence", methods=["POST"]) diff --git a/manifest/api/models/diffuser.py b/manifest/api/models/diffuser.py index ffbc403..b42ed3a 100644 --- a/manifest/api/models/diffuser.py +++ b/manifest/api/models/diffuser.py @@ -1,7 +1,8 @@ -"""Huggingface model.""" +"""Diffuser model.""" from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from diffusers import StableDiffusionPipeline @@ -93,6 +94,19 @@ class DiffuserModel(Model): # Return None for logprobs and token logprobs return [(im, None, None, None) for im in result["images"]] + @torch.no_grad() + def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray: + """ + Embed the prompt from model. + + Args: + prompt: promt to embed from. + + Returns: + list of embeddings (list of length 1 for 1 embedding). + """ + raise NotImplementedError("Embed not supported for diffusers") + @torch.no_grad() def score_sequence( self, prompt: Union[str, List[str]], **kwargs: Any diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py index 038b7c3..69c1eb6 100644 --- a/manifest/api/models/huggingface.py +++ b/manifest/api/models/huggingface.py @@ -535,15 +535,32 @@ class TextGenerationModel(HuggingFaceModel): @torch.no_grad() def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray: """ - Compute embedding for prompts. + Embed the prompt from model. Args: - prompt: promt to generate from. + prompt: promt to embed from. Returns: - embedding + list of embeddings (list of length 1 for 1 embedding). 
""" - pass + if isinstance(prompt, str): + prompt = [prompt] + encoded_prompt = self.pipeline.tokenizer( + prompt, + max_length=self.pipeline.max_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + encoded_prompt = encoded_prompt.to(self.pipeline.device) + # Get last hidden state + output = self.pipeline.model( # type: ignore + **encoded_prompt, + output_hidden_states=True, + return_dict=True, + ) + last_hidden_state = output["hidden_states"][-1][:, -1, :] + return last_hidden_state.cpu().numpy() @torch.no_grad() def generate( diff --git a/manifest/api/models/model.py b/manifest/api/models/model.py index 84d91ab..3317211 100644 --- a/manifest/api/models/model.py +++ b/manifest/api/models/model.py @@ -1,14 +1,12 @@ """Model class.""" -from abc import ABC, abstractmethod from typing import Any, Dict, List, Tuple, Union import numpy as np -class Model(ABC): +class Model: """Model class.""" - @abstractmethod def __init__( self, model_name_or_path: str, @@ -41,7 +39,6 @@ class Model(ABC): """ raise NotImplementedError() - @abstractmethod def get_init_params(self) -> Dict: """Return init params to determine what model is being used.""" raise NotImplementedError() @@ -66,13 +63,13 @@ class Model(ABC): def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray: """ - Compute embedding for prompts. + Embed the prompt from model. Args: - prompt: promt to generate from. + prompt: promt to embed from. Returns: - embedding + list of embeddings (list of length 1 for 1 embedding). """ raise NotImplementedError() diff --git a/manifest/api/models/sentence_transformer.py b/manifest/api/models/sentence_transformer.py new file mode 100644 index 0000000..bd3f5fa --- /dev/null +++ b/manifest/api/models/sentence_transformer.py @@ -0,0 +1,113 @@ +"""Sentence transformer model.""" +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from sentence_transformers import SentenceTransformer + +from manifest.api.models.model import Model + + +class SentenceTransformerModel(Model): + """SentenceTransformer model.""" + + def __init__( + self, + model_name_or_path: str, + model_type: Optional[str] = None, + model_config: Optional[str] = None, + cache_dir: Optional[str] = None, + device: int = 0, + use_accelerate: bool = False, + use_parallelize: bool = False, + use_bitsandbytes: bool = False, + use_deepspeed: bool = False, + perc_max_gpu_mem_red: float = 1.0, + use_fp16: bool = False, + ): + """ + Initialize model. + + All arguments will be passed in the request from Manifest. + + Args: + model_name_or_path: model name string. + model_config: model config string. + cache_dir: cache directory for model. + device: device to use for model. + use_accelerate: whether to use accelerate for multi-gpu inference. + use_parallelize: use HF default parallelize + use_bitsandbytes: use HF bits and bytes + use_deepspeed: use deepspeed + perc_max_gpu_mem_red: percent max memory reduction in accelerate + use_fp16: use fp16 for model weights. 
+ """ + if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed: + raise ValueError( + "Cannot use accelerate or parallelize or " + "bitsandbytes or deepspeeed with sentence transformers" + ) + # Check if providing path + self.model_name = model_name_or_path + print("Model Name:", self.model_name) + torch_device = ( + torch.device("cpu") + if (device == -1 or not torch.cuda.is_available()) + else torch.device(f"cuda:{device}") + ) + self.embedding_model = SentenceTransformer(self.model_name, device=torch_device) + self.embedding_model.to(torch_device) + self.embedding_model.eval() + + def get_init_params(self) -> Dict: + """Return init params to determine what model is being used.""" + return {"model_name": self.model_name, "model_path": self.model_name} + + @torch.no_grad() + def generate( + self, prompt: Union[str, List[str]], **kwargs: Any + ) -> List[Tuple[Any, float, List[int], List[float]]]: + """ + Generate the prompt from model. + + Outputs must be generated text and score, not including prompt. + + Args: + prompt: promt to generate from. + + Returns: + list of generated text (list of length 1 for 1 generation). + """ + raise NotImplementedError("Generate not supported for sentence transformers") + + @torch.no_grad() + def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray: + """ + Embed the prompt from model. + + Args: + prompt: promt to embed from. + + Returns: + list of embeddings (list of length 1 for 1 embedding). + """ + if isinstance(prompt, str): + prompt = [prompt] + return self.embedding_model.encode(prompt) + + @torch.no_grad() + def score_sequence( + self, prompt: Union[str, List[str]], **kwargs: Any + ) -> List[Tuple[float, List[int], List[float]]]: + """ + Score a sequence of choices. + + Args: + prompt (:obj:`str` or :obj:`List[str]`): + The prompt to score the choices against. + **kwargs: + Additional keyword arguments passed along to the :obj:`__call__` method. + """ + raise NotImplementedError( + "Score sequence not supported for sentence transformers" + ) diff --git a/manifest/api/response.py b/manifest/api/response.py index a0860e9..6c55489 100644 --- a/manifest/api/response.py +++ b/manifest/api/response.py @@ -16,17 +16,24 @@ class ModelResponse: "text_completion", "prompt_logit_score", "image_generation", + "embedding_generation", }: raise ValueError( f"Invalid response type: {self.response_type}. " - "Must be one of: text_completion, prompt_logit_score, image_generation." + "Must be one of: text_completion, prompt_logit_score, " + "image_generation, embedding_generation." ) self.response_id = str(uuid.uuid4()) self.created = int(time.time()) def __dict__(self) -> Dict[str, Any]: # type: ignore """Return dictionary representation of response.""" - key = "text" if self.response_type != "image_generation" else "array" + key = ( + "text" + if self.response_type + not in {"prompt_logit_score", "image_generation", "embedding_generation"} + else "array" + ) return { "id": self.response_id, "object": self.response_type, diff --git a/manifest/caches/cache.py b/manifest/caches/cache.py index 74df2fb..347b982 100644 --- a/manifest/caches/cache.py +++ b/manifest/caches/cache.py @@ -10,6 +10,7 @@ ARRAY_CACHE_TYPES = { "diffuser", "tomadiffuser", "openaiembedding", + "huggingfaceembedding", } diff --git a/manifest/clients/diffuser.py b/manifest/clients/diffuser.py index bb8db66..e0254d2 100644 --- a/manifest/clients/diffuser.py +++ b/manifest/clients/diffuser.py @@ -82,8 +82,9 @@ class DiffuserClient(Client): Returns: model params. 
""" - res = requests.post(self.host + "/params") - return res.json() + res = requests.post(self.host + "/params").json() + res["client_name"] = self.NAME + return res def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py index f29bc43..4e1d95d 100644 --- a/manifest/clients/huggingface.py +++ b/manifest/clients/huggingface.py @@ -76,8 +76,9 @@ class HuggingFaceClient(Client): Returns: model params. """ - res = requests.post(self.host + "/params") - return res.json() + res = requests.post(self.host + "/params").json() + res["client_name"] = self.NAME + return res def get_score_prompt_request( self, diff --git a/manifest/clients/huggingface_embedding.py b/manifest/clients/huggingface_embedding.py new file mode 100644 index 0000000..a052b85 --- /dev/null +++ b/manifest/clients/huggingface_embedding.py @@ -0,0 +1,89 @@ +"""Hugging Face client.""" +import logging +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import requests + +from manifest.clients.client import Client +from manifest.request import EmbeddingRequest + +logger = logging.getLogger(__name__) + + +class HuggingFaceEmbeddingClient(Client): + """HuggingFaceEmbedding client.""" + + # User param -> (client param, default value) + PARAMS: Dict[str, Tuple[str, Any]] = {} + REQUEST_CLS = EmbeddingRequest + NAME = "huggingfaceembedding" + + def connect( + self, + connection_str: Optional[str] = None, + client_args: Dict[str, Any] = {}, + ) -> None: + """ + Connect to the HuggingFace url. + + Arsg: + connection_str: connection string. + client_args: client arguments. + """ + if not connection_str: + raise ValueError("Must provide connection string") + self.host = connection_str.rstrip("/") + for key in self.PARAMS: + setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) + + def close(self) -> None: + """Close the client.""" + pass + + def get_generation_url(self) -> str: + """Get generation URL.""" + return self.host + "/embed" + + def get_generation_header(self) -> Dict[str, str]: + """ + Get generation header. + + Returns: + header. + """ + return {} + + def supports_batch_inference(self) -> bool: + """Return whether the client supports batch inference.""" + return True + + def get_model_params(self) -> Dict: + """ + Get model params. + + By getting model params from the server, we can add to request + and make sure cache keys are unique to model. + + Returns: + model params. + """ + res = requests.post(self.host + "/params").json() + res["client_name"] = self.NAME + return res + + def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + """ + Format response to dict. 
+ + Args: + response: response + request: request + + Return: + response as dict + """ + # Convert array to np.array + for choice in response["choices"]: + choice["array"] = np.array(choice["array"]) + return response diff --git a/manifest/manifest.py b/manifest/manifest.py index 0717c61..805366a 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -13,6 +13,7 @@ from manifest.clients.ai21 import AI21Client from manifest.clients.cohere import CohereClient from manifest.clients.dummy import DummyClient from manifest.clients.huggingface import HuggingFaceClient +from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient from manifest.clients.openai import OpenAIClient from manifest.clients.openai_chat import OpenAIChatClient from manifest.clients.openai_embedding import OpenAIEmbeddingClient @@ -30,6 +31,7 @@ CLIENT_CONSTRUCTORS = { CohereClient.NAME: CohereClient, AI21Client.NAME: AI21Client, HuggingFaceClient.NAME: HuggingFaceClient, + HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient, DummyClient.NAME: DummyClient, TOMAClient.NAME: TOMAClient, } diff --git a/manifest/request.py b/manifest/request.py index a6c31de..82af1cf 100644 --- a/manifest/request.py +++ b/manifest/request.py @@ -1,5 +1,5 @@ """Request object.""" -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel @@ -99,8 +99,7 @@ class LMRequest(Request): class EmbeddingRequest(Request): """Embedding Request object.""" - # Aggregate method (if applicable) - aggregation_method: Optional[Literal["last_token", "mean"]] = None + pass class DiffusionRequest(Request): diff --git a/manifest/response.py b/manifest/response.py index 2b9a1ef..5c751ab 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -17,6 +17,10 @@ RESPONSE_CONSTRUCTORS = { "logits_key": "token_logprobs", "item_key": "array", }, + "huggingfaceembedding": { + "logits_key": "token_logprobs", + "item_key": "array", + }, } diff --git a/pyproject.toml b/pyproject.toml index 7b4b0ed..0e63470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ module = [ "deepspeed", "numpy", "diffusers", + "sentence_transformers", "sqlitedict", "dill", "accelerate", diff --git a/setup.py b/setup.py index 4ddd8ec..5952017 100644 --- a/setup.py +++ b/setup.py @@ -42,12 +42,13 @@ REQUIRED = [ # What packages are optional? 
EXTRAS = { "api": [ + "accelerate>=0.10.0", "deepspeed>=0.7.0", "diffusers>=0.6.0", "Flask>=2.1.2", - "accelerate>=0.10.0", - "transformers>=4.20.0,<4.26.0", + "sentence_transformers>=2.2.0", "torch>=1.8.0", + "transformers>=4.20.0,<4.26.0", ], "app": [ "fastapi>=0.70.0", diff --git a/tests/test_huggingface_api.py b/tests/test_huggingface_api.py index b23ff5f..75446eb 100644 --- a/tests/test_huggingface_api.py +++ b/tests/test_huggingface_api.py @@ -4,9 +4,11 @@ import math import os from subprocess import PIPE, Popen +import numpy as np import pytest from manifest.api.models.huggingface import MODEL_REGISTRY, TextGenerationModel +from manifest.api.models.sentence_transformer import SentenceTransformerModel NOCUDA = 0 try: @@ -147,6 +149,37 @@ def test_gpt_score() -> None: assert isinstance(result[1][1], list) +def test_embed() -> None: + """Test embedding pipeline.""" + model = TextGenerationModel( + model_name_or_path="gpt2", + use_accelerate=False, + use_parallelize=False, + use_bitsandbytes=False, + use_deepspeed=False, + use_fp16=False, + device=-1, + ) + inputs = ["Why is the sky green?", "Cats are butterflies"] + embeddings = model.embed(inputs) + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (2, 768) + + model2 = SentenceTransformerModel( + model_name_or_path="all-mpnet-base-v2", + use_accelerate=False, + use_parallelize=False, + use_bitsandbytes=False, + use_deepspeed=False, + use_fp16=False, + device=-1, + ) + inputs = ["Why is the sky green?", "Cats are butterflies"] + embeddings = model2.embed(inputs) + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (2, 768) + + def test_batch_gpt_generate() -> None: """Test pipeline generation from a gpt model.""" model = TextGenerationModel( diff --git a/tests/test_manifest.py b/tests/test_manifest.py index b6b99ba..f273787 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -551,6 +551,101 @@ def test_local_huggingface(sqlite_cache: str) -> None: ) +@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}") +@pytest.mark.usefixtures("sqlite_cache") +def test_local_huggingfaceembedding(sqlite_cache: str) -> None: + """Test openaichat client.""" + client = Manifest( + client_name="huggingfaceembedding", + client_connection=URL, + cache_name="sqlite", + cache_connection=sqlite_cache, + ) + + res = client.run("Why are there carrots?") + assert isinstance(res, np.ndarray) + + response = cast( + Response, client.run("Why are there carrots?", return_response=True) + ) + assert isinstance(response.get_response(), np.ndarray) + assert np.allclose(response.get_response(), res) + + client = Manifest( + client_name="huggingfaceembedding", + client_connection=URL, + cache_name="sqlite", + cache_connection=sqlite_cache, + ) + + res = client.run("Why are there apples?") + assert isinstance(res, np.ndarray) + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert isinstance(response.get_response(), np.ndarray) + assert np.allclose(response.get_response(), res) + assert response.is_cached() is True + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert response.is_cached() is True + + res_list = client.run(["Why are there apples?", "Why are there bananas?"]) + assert ( + isinstance(res_list, list) + and len(res_list) == 2 + and isinstance(res_list[0], np.ndarray) + ) + + response = cast( + Response, + client.run( + ["Why are there apples?", "Why are there mangos?"], return_response=True + ), + ) + assert 
(
+        isinstance(response.get_response(), list) and len(response.get_response()) == 2
+    )
+
+    response = cast(
+        Response, client.run("Why are there bananas?", return_response=True)
+    )
+    assert response.is_cached() is True
+
+    response = cast(
+        Response, client.run("Why are there oranges?", return_response=True)
+    )
+    assert response.is_cached() is False
+
+    res_list = asyncio.run(
+        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
+    )
+    assert (
+        isinstance(res_list, list)
+        and len(res_list) == 2
+        and isinstance(res_list[0], np.ndarray)
+    )
+
+    response = cast(
+        Response,
+        asyncio.run(
+            client.arun_batch(
+                ["Why are there pinenuts?", "Why are there cocoa?"],
+                return_response=True,
+            )
+        ),
+    )
+    assert (
+        isinstance(response.get_response(), list)
+        and len(response.get_response()) == 2
+        and isinstance(response.get_response()[0], np.ndarray)
+    )
+
+    response = cast(
+        Response, client.run("Why are there oranges?", return_response=True)
+    )
+    assert response.is_cached() is True
+
+
 @pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
 @pytest.mark.usefixtures("sqlite_cache")
 def test_openai(sqlite_cache: str) -> None: