feat: openai embedding support (#75)

pull/82/head
Laurel Orr 1 year ago committed by GitHub
parent 693d105106
commit 40de0e7f59

@ -7,6 +7,7 @@ How to make prompt programming with Foundation Models a little easier.
- [Getting Started](#getting-started)
- [Manifest](#manifest-components)
- [Local HuggingFace Models](#local-huggingface-models)
- [Embedding Models](#embedding-models)
- [Development](#development)
- [Cite](#cite)
@ -47,6 +48,9 @@ manifest = Manifest(
manifest.run("Why is the grass green?")
```
## Examples
We have example notebooks and Python scripts located at [examples](examples). These show how to use different models, model types (i.e., text, diffusers, or embedding models), and async running.
# Manifest Components
Manifest is meant to be a very lightweight package to help with prompt design and iteration. Three key design decisions of Manifest are
@ -112,6 +116,12 @@ You can also run over multiple examples if supported by the client.
results = manifest.run(["Where are the cats?", "Where are the dogs?"])
```
We support async queries as well via
```python
import asyncio
results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the dogs?"]))
```
If something doesn't go right, you can also ask to get a raw manifest Response.
```python
result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
@ -178,6 +188,23 @@ python3 -m manifest.api.app \
--percent_max_gpu_mem_reduction 0.85
```
# Embedding Models
Manifest also supports getting embeddings from local models and available APIs. This is all controlled through the `client_name` argument. You still use `run` and `arun_batch`.
To use OpenAI's embedding models, simply run
```python
manifest = Manifest(client_name="openaiembedding")
embedding_as_np = manifest.run("Get me an embedding for a bunny")
```
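Batch and asynchronous embedding requests work the same way as for the text clients; each prompt comes back as a numpy array. A minimal sketch (the prompts are illustrative and `OPENAI_API_KEY` must be set):
```python
import asyncio

import numpy as np

from manifest import Manifest

manifest = Manifest(client_name="openaiembedding")

# Batch run: one numpy array per prompt.
embeddings = manifest.run(["A bunny", "A carrot"])
assert isinstance(embeddings[0], np.ndarray)

# Async batch run mirrors the text clients.
embeddings = asyncio.run(manifest.arun_batch(["A bunny", "A carrot"]))
```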
As explained above, you can load local HuggingFace models that give you embeddings, too. If you want to use a standard generative model, load the model as above and use `client_name="huggingfaceembedding"`. If you want to use a standard embedding model, like those from SentenceTransformers, load your local model via
```bash
python3 -m manifest.api.app \
--model_type sentence_transformers \
--model_name_or_path all-mpnet-base-v2 \
--device 0
```
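Once a local server is running, connect to it from Manifest as in the example notebook. A minimal sketch, shown here for the `huggingfaceembedding` client (the connection string and port are illustrative, adjust to your deployment):
```python
from manifest import Manifest

# Connect to a locally hosted embedding server.
manifest = Manifest(
    client_name="huggingfaceembedding",
    client_connection="http://127.0.0.1:6001",
)
embedding_as_np = manifest.run("Get me an embedding for a bunny")
```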
# Development
Before submitting a PR, run
```bash

@ -0,0 +1,139 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use OpenAI\n",
"\n",
"Set your `OPENAI_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
]
}
],
"source": [
"from manifest import Manifest\n",
"\n",
"manifest = Manifest(client_name=\"openaiembedding\")\n",
"print(manifest.client.get_model_params())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1536,)\n"
]
}
],
"source": [
"emb = manifest.run(\"Is this an embedding?\")\n",
"print(emb.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using Locally Hosted Huggingface LM\n",
"\n",
"Run\n",
"```\n",
"python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
"```\n",
"in a separate `screen` or `tmux`."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_name': 'EleutherAI/gpt-neo-125M', 'model_path': 'EleutherAI/gpt-neo-125M'}\n"
]
}
],
"source": [
"from manifest import Manifest\n",
"\n",
"# Local hosted GPT Neo 125M\n",
"manifest = Manifest(\n",
" client_name=\"huggingfaceembedding\",\n",
" client_connection=\"http://127.0.0.1:6001\",\n",
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client.get_model_params())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"emb = manifest.run(\"Is this an embedding?\")\n",
"emb2 = manifest.run(\"Is this an embedding?\", aggregation=\"mean\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -2,12 +2,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
from manifest.caches.serializers import ArraySerializer, Serializer
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
from manifest.response import RESPONSE_CONSTRUCTORS, Response
CACHE_CONSTRUCTOR = {
"diffuser": ArraySerializer,
"tomadiffuser": ArraySerializer,
# Non-text return type caches
ARRAY_CACHE_TYPES = {
"diffuser",
"tomadiffuser",
"openaiembedding",
}
@ -21,17 +23,21 @@ class Cache(ABC):
cache_args: Dict[str, Any] = {},
):
"""
Initialize client.
Initialize cache.
Args:
connection_str: connection string.
client_name: name of client.
cache_args: arguments for cache.
cache_args are passed to client as default parameters.
cache_args are any arguments needed to initialize the cache.
For clients like OpenAI that do not require a connection,
the connection_str can be None.
Further, cache_args can contain `array_serializer` as a string
for embedding or image return types (e.g. diffusers), with value
`local_file` or `byte_string`. `local_file` will save the
array in a local file and cache a pointer to the file.
`byte_string` will convert the array to a byte string and cache
the entire byte string. `byte_string` is the default.
Args:
connection_str: connection string for client.
@ -39,7 +45,22 @@ class Cache(ABC):
"""
self.client_name = client_name
self.connect(connection_str, cache_args)
self.serializer = CACHE_CONSTRUCTOR.get(client_name, Serializer)()
if self.client_name in ARRAY_CACHE_TYPES:
array_serializer = cache_args.pop("array_serializer", "byte_string")
if array_serializer not in ["local_file", "byte_string"]:
raise ValueError(
"array_serializer must be local_file or byte_string,"
f" not {array_serializer}"
)
self.serializer = (
ArraySerializer()
if array_serializer == "local_file"
else NumpyByteSerializer()
)
else:
# If user has array_serializer type, it will throw an error as
# it is not recognized for non-array return types.
self.serializer = Serializer()
@abstractmethod
def close(self) -> None:
@ -107,7 +128,7 @@ class Cache(ABC):
response,
cached,
request,
**RESPONSE_CONSTRUCTORS.get(self.client_name, {})
**RESPONSE_CONSTRUCTORS.get(self.client_name, {}),
)
return None
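For reference, a hedged sketch of how the `array_serializer` option described in the docstring above is exercised, mirroring the tests (paths and client names are illustrative, and the OpenAI example assumes `OPENAI_API_KEY` is set):
```python
from manifest import Manifest
from manifest.caches.sqlite import SQLiteCache

# Through the Manifest constructor: store arrays as pointers to local files.
manifest = Manifest(
    client_name="openaiembedding",
    cache_name="sqlite",
    cache_connection="my_manifest_cache.sqlite",
    array_serializer="local_file",
)

# Directly through a cache: keep the default byte-string serialization.
cache = SQLiteCache(
    "my_manifest_cache.sqlite",
    client_name="diffuser",
    cache_args={"array_serializer": "byte_string"},
)
```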

@ -1,10 +1,12 @@
"""Serializer."""
import io
import json
import os
from pathlib import Path
from typing import Dict
import numpy as np
import xxhash
from manifest.caches.array_cache import ArrayCache
@ -62,6 +64,64 @@ class Serializer:
return json.loads(key)
class NumpyByteSerializer(Serializer):
"""Serializer by casting array to byte string."""
def response_to_key(self, response: Dict) -> str:
"""
Normalize a response into a key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
# Assume response is a dict with keys "choices" -> List dicts
# with keys "array".
choices = response["choices"]
# We don't want to modify the response in place
# but we want to avoid calling deepcopy on an array
del response["choices"]
response_copy = response.copy()
response["choices"] = choices
response_copy["choices"] = []
for choice in choices:
if "array" not in choice:
raise ValueError(
f"Choice with keys {choice.keys()} does not have array key."
)
arr = choice["array"]
# Avoid copying an array
del choice["array"]
new_choice = choice.copy()
choice["array"] = arr
with io.BytesIO() as f:
np.savez_compressed(f, data=arr)
hash_str = f.getvalue().hex()
new_choice["array"] = hash_str
response_copy["choices"].append(new_choice)
return json.dumps(response_copy, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
Convert the normalized version to the response.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
response = json.loads(key)
for choice in response["choices"]:
hash_str = choice["array"]
byte_str = bytes.fromhex(hash_str)
with io.BytesIO(byte_str) as f:
choice["array"] = np.load(f)["data"]
return response
class ArraySerializer(Serializer):
"""Serializer for array."""

@ -92,7 +92,7 @@ class AI21Client(Client):
Returns:
model params.
"""
return {"model_name": "ai21", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""

@ -91,7 +91,7 @@ class CohereClient(Client):
Returns:
model params.
"""
return {"model_name": "cohere", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""

@ -1,12 +1,12 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type
import tiktoken
from manifest.clients.client import Client
from manifest.request import LMRequest
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
@ -41,7 +41,7 @@ class OpenAIClient(Client):
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
}
REQUEST_CLS = LMRequest
REQUEST_CLS: Type[Request] = LMRequest
NAME = "openai"
def connect(
@ -103,7 +103,7 @@ class OpenAIClient(Client):
Returns:
model params.
"""
return {"model_name": "openai", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""

@ -77,7 +77,7 @@ class OpenAIChatClient(OpenAIClient):
Returns:
model params.
"""
return {"model_name": "openaichat", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for chat.
@ -99,8 +99,8 @@ class OpenAIChatClient(OpenAIClient):
request_params["messages"] = messages
return request_params
def _format_request_for_text(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response for text.
def _format_request_from_chat(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response for standard response from chat.
Args:
response_dict: response.
@ -131,7 +131,7 @@ class OpenAIChatClient(OpenAIClient):
request_params = self._format_request_for_chat(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_for_text(response_dict)
response_dict = self._format_request_from_chat(response_dict)
return response_dict
async def _arun_completion(
@ -153,5 +153,5 @@ class OpenAIChatClient(OpenAIClient):
request_params, retry_timeout, batch_size
)
# Reformat for text model
response_dict = self._format_request_for_text(response_dict)
response_dict = self._format_request_from_chat(response_dict)
return response_dict

@ -0,0 +1,204 @@
"""OpenAI embedding client."""
import copy
import logging
import os
from typing import Any, Dict, List, Optional
import numpy as np
import tiktoken
from manifest.clients.openai import OpenAIClient
from manifest.request import EmbeddingRequest
logger = logging.getLogger(__name__)
OPENAI_EMBEDDING_ENGINES = {
"text-embedding-ada-002",
}
class OpenAIEmbeddingClient(OpenAIClient):
"""OpenAI embedding client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "text-embedding-ada-002"),
}
REQUEST_CLS = EmbeddingRequest
NAME = "openaiembedding"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OpenAI server.
connection_str is passed as default OPENAI_API_KEY if variable not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
"variable or pass through `client_connection`."
)
self.host = "https://api.openai.com/v1"
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_EMBEDDING_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. "
f"Must be {OPENAI_EMBEDDING_ENGINES}."
)
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/embeddings"
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "data" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["data"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["data"])
if split_usage:
response["usage"] = split_usage
return response
def _format_request_for_embedding(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for embedding.
Args:
request_params: request params.
Returns:
formatted request params.
"""
# Format for embedding model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
request_params["input"] = prompt_list
return request_params
def _format_request_from_embedding(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response from embedding for standard response.
Args:
response_dict: response.
Return:
formatted response.
"""
new_choices = []
response_dict = copy.deepcopy(response_dict)
for res in response_dict.pop("data"):
new_choices.append({"array": np.array(res["embedding"])})
response_dict["choices"] = new_choices
return response_dict
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat embedding response into the standard response format
response_dict = self._format_request_from_embedding(response_dict)
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
# Reformat embedding response into the standard response format
response_dict = self._format_request_from_embedding(response_dict)
return response_dict
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
try:
encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
except Exception:
return []
prompt = request["input"]
if isinstance(prompt, str):
prompts = [prompt]
else:
prompts = prompt
assert len(prompts) == len(choices)
usages = []
for pmt in prompts:
pmt_tokens = len(encoding.encode(pmt))
# No completion tokens for embedding models
chc_tokens = 0
usage = {
"prompt_tokens": pmt_tokens,
"completion_tokens": chc_tokens,
"total_tokens": pmt_tokens + chc_tokens,
}
usages.append(usage)
return usages
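To make the data flow above concrete, here is a standalone sketch (not the client itself) of the two reformatting steps: prompts become the `input` list sent to the embeddings endpoint, and the returned `data` entries become `choices` holding numpy arrays, which is what the array serializers and the `RESPONSE_CONSTRUCTORS` entry below expect. The request/response pair is fabricated for illustration only.
```python
import copy
from typing import Any, Dict

import numpy as np


def to_embedding_request(request_params: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror _format_request_for_embedding: prompt(s) -> "input" list."""
    params = copy.deepcopy(request_params)
    prompt = params.pop("prompt")
    params["input"] = [prompt] if isinstance(prompt, str) else list(prompt)
    return params


def from_embedding_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror _format_request_from_embedding: "data" -> "choices" of arrays."""
    response = copy.deepcopy(response)
    response["choices"] = [
        {"array": np.array(item["embedding"])} for item in response.pop("data")
    ]
    return response


print(to_embedding_request({"prompt": "a bunny", "model": "text-embedding-ada-002"}))
print(from_embedding_response({"data": [{"embedding": [0.1, 0.2, 0.3]}]}))
```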

@ -121,7 +121,7 @@ class TOMAClient(Client):
Returns:
model params.
"""
return {"model_name": "toma", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def get_model_heartbeats(self) -> Dict[str, Dict]:
"""

@ -44,7 +44,7 @@ class TOMADiffuserClient(TOMAClient):
Returns:
model params.
"""
return {"model_name": "tomadiffuser", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""

@ -14,7 +14,8 @@ from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openaichat import OpenAIChatClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.toma import TOMAClient
from manifest.request import Request
from manifest.response import Response
@ -25,6 +26,7 @@ logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
HuggingFaceClient.NAME: HuggingFaceClient,

@ -1,5 +1,5 @@
"""Request object."""
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel
@ -96,6 +96,13 @@ class LMRequest(Request):
frequency_penalty: float = 0
class EmbeddingRequest(Request):
"""Embedding Request object."""
# Aggregate method (if applicable)
aggregation_method: Optional[Literal["last_token", "mean"]] = None
class DiffusionRequest(Request):
"""Diffusion Model Request object."""

@ -13,6 +13,10 @@ RESPONSE_CONSTRUCTORS = {
"logits_key": "token_logprobs",
"item_key": "array",
},
"openaiembedding": {
"logits_key": "token_logprobs",
"item_key": "array",
},
}

@ -1,5 +1,5 @@
"""Cache test."""
from typing import cast
from typing import Dict, cast
import numpy as np
import pytest
@ -13,12 +13,15 @@ from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
def _get_postgres_cache(**kwargs) -> Cache: # type: ignore
def _get_postgres_cache(
client_name: str = "", cache_args: Dict = {}
) -> Cache: # type: ignore
"""Get postgres cache."""
cache_args.update({"cache_user": "", "cache_password": "", "cache_db": ""})
return PostgresCache(
"postgres",
cache_args={"cache_user": "", "cache_password": "", "cache_db": ""},
**kwargs,
client_name=client_name,
cache_args=cache_args,
)
@ -96,11 +99,11 @@ def test_get(
assert response.is_cached()
assert response.get_request() == test_request
# Test array
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}]}
# Test array
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
elif cache_type == "redis":
@ -117,6 +120,37 @@ def test_get(
assert response.is_cached()
assert response.get_request() == test_request
# Test array byte string
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}]}
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert response.is_cached()
assert response.get_request() == test_request
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@ -168,6 +202,39 @@ def test_get_batch_prompt(
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays byte serializer
arr = np.random.rand(4, 4)
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
def test_noop_cache() -> None:
"""Test cache that is a no-op cache."""

@ -4,6 +4,7 @@ import os
from typing import cast
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
import requests
from requests import HTTPError
@ -567,6 +568,7 @@ def test_openai(sqlite_cache: str) -> None:
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.get_response() == res
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 15
@ -643,6 +645,7 @@ def test_openaichat(sqlite_cache: str) -> None:
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.get_response() == res
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 23
@ -685,6 +688,114 @@ def test_openaichat(sqlite_cache: str) -> None:
assert response.is_cached() is True
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openaiembedding(sqlite_cache: str) -> None:
"""Test openaiembedding client."""
client = Manifest(
client_name="openaiembedding",
cache_name="sqlite",
cache_connection=sqlite_cache,
array_serializer="local_file",
)
res = client.run("Why are there carrots?")
assert isinstance(res, np.ndarray)
response = cast(
Response, client.run("Why are there carrots?", return_response=True)
)
assert isinstance(response.get_response(), np.ndarray)
assert np.allclose(response.get_response(), res)
client = Manifest(
client_name="openaiembedding",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
res = client.run("Why are there apples?")
assert isinstance(res, np.ndarray)
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), np.ndarray)
assert np.allclose(response.get_response(), res)
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 5
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
res_list = client.run(["Why are there apples?", "Why are there bananas?"])
assert (
isinstance(res_list, list)
and len(res_list) == 2
and isinstance(res_list[0], np.ndarray)
)
response = cast(
Response,
client.run(
["Why are there apples?", "Why are there mangos?"], return_response=True
),
)
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 5
assert response.get_json_response()["usage"][1]["total_tokens"] == 6
response = cast(
Response, client.run("Why are there bananas?", return_response=True)
)
assert response.is_cached() is True
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is False
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert (
isinstance(res_list, list)
and len(res_list) == 2
and isinstance(res_list[0], np.ndarray)
)
response = cast(
Response,
asyncio.run(
client.arun_batch(
["Why are there pinenuts?", "Why are there cocoa?"],
return_response=True,
)
),
)
assert (
isinstance(response.get_response(), list)
and len(res_list) == 2
and isinstance(res_list[0], np.ndarray)
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 7
assert response.get_json_response()["usage"][1]["total_tokens"] == 5
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True
def test_retry_handling() -> None:
"""Test retry handling."""
# We'll mock the response so we won't need a real connection

@ -3,10 +3,10 @@ import json
import numpy as np
from manifest.caches.serializers import ArraySerializer
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer
def test_response_to_key() -> None:
def test_response_to_key_array() -> None:
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)
@ -17,3 +17,16 @@ def test_response_to_key() -> None:
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])
def test_response_to_key_numpybytes() -> None:
"""Test numpy byte serializer response-to-key round trip."""
serializer = NumpyByteSerializer()
arr = np.random.rand(4, 4)
res = {"choices": [{"array": arr}]}
key = serializer.response_to_key(res)
key_dct = json.loads(key)
assert isinstance(key_dct["choices"][0]["array"], str)
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])
