fix: added pydantic types to response (#84)

Laurel Orr, 1 year ago, committed by GitHub
parent 4602fb919b
commit d7401c6ec5

@@ -8,6 +8,7 @@ Added
 Fixed
 ^^^^^
 * Determine cache and response by request type, not client name
+* Refactor Response to use Pydantic types for Request and Response
 0.1.1
 ---------------------

@@ -126,9 +126,9 @@ results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the
 If something doesn't go right, you can also ask to get a raw manifest Response.
 ```python
 result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
-print(result_object.get_request())
+print(result_object.get_request_obj())
 print(result_object.is_cached())
-print(result_object.get_json_response())
+print(result_object.get_response_obj())
 ```
 By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run`.
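With this change, the raw manifest `Response` hands back Pydantic objects rather than plain dicts. A minimal sketch of inspecting one (hypothetical prompt; assumes a session configured as in the README above):

```python
result_object = manifest.run("Where are the cats?", return_response=True)

# get_response_obj() returns a ModelChoices pydantic model
for choice in result_object.get_response_obj().choices:
    print(choice.text, choice.token_logprobs)

# get_usage_obj() returns a Usages pydantic model (may be empty)
for usage in result_object.get_usage_obj().usages:
    print(usage.total_tokens)
```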

@@ -12,7 +12,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 29,
+"execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -32,7 +32,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -56,7 +56,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -89,7 +89,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -100,7 +100,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 25,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -131,23 +131,23 @@
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"For loop: 229.93\n",
+"For loop: 128.68\n",
 "Running with async single client\n",
 "Running 1 tasks across all clients.\n",
-"Async loop: 1.39\n",
+"Async loop: 4.02\n",
 "Running with async two clients but not chunking\n",
 "Running 1 tasks across all clients.\n",
-"Async loop: 1.65\n",
+"Async loop: 3.92\n",
 "Running with async two clients and chunk size\n",
 "Running 20 tasks across all clients.\n",
-"Async loop: 0.64\n"
+"Async loop: 1.44\n"
 ]
 }
 ],

@@ -4,7 +4,7 @@ from typing import Any, Dict, Type, Union
 from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
 from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
-from manifest.response import RESPONSE_CONSTRUCTORS, Response
+from manifest.response import Response
 # Non-text return type caches
 ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
@@ -119,14 +119,9 @@ class Cache(ABC):
         key = self.serializer.request_to_key(request)
         cached_response = self.get_key(key)
         if cached_response:
-            cached = True
             response = self.serializer.key_to_response(cached_response)
-            return Response(
-                response,
-                cached,
-                request,
-                **RESPONSE_CONSTRUCTORS.get(self.request_type, {}),
-            )
+            response["cached"] = True
+            return Response.from_dict(response, request_dict=request)
         return None
     def set(self, request: Dict, response: Dict) -> None:
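The cache now stores the full `Response.to_dict(drop_request=True)` payload and rebuilds a typed `Response` on read via `Response.from_dict`, re-attaching the request dict that keyed the lookup. A minimal sketch of that round trip outside any cache backend (a plain dict stands in for the key-value store; names follow the diff):

```python
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Response

store = {}  # stand-in for the key-value backend
request = LMRequest(prompt="hello", engine="dummy")

response = Response(
    response=ModelChoices(choices=[LMModelChoice(text="world")]),
    cached=False,
    request=request,
    usages=None,
    response_type="text",
    request_type=LMRequest,
)

# set(): persist the payload without the request
store["key"] = response.to_dict(drop_request=True)

# get(): mark the hit as cached and rebuild, re-attaching the request
payload = store["key"]
payload["cached"] = True
cached = Response.from_dict(payload, request_dict=request.dict())
assert cached.is_cached()
```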

@@ -77,14 +77,15 @@ class NumpyByteSerializer(Serializer):
         Returns:
             normalized key.
         """
+        sub_response = response["response"]
         # Assume response is a dict with keys "choices" -> List dicts
         # with keys "array".
-        choices = response["choices"]
+        choices = sub_response["choices"]
         # We don't want to modify the response in place
         # but we want to avoid calling deepcopy on an array
-        del response["choices"]
-        response_copy = response.copy()
-        response["choices"] = choices
+        del sub_response["choices"]
+        response_copy = sub_response.copy()
+        sub_response["choices"] = choices
         response_copy["choices"] = []
         for choice in choices:
             if "array" not in choice:
@@ -101,7 +102,8 @@ class NumpyByteSerializer(Serializer):
                 hash_str = f.getvalue().hex()
                 new_choice["array"] = hash_str
                 response_copy["choices"].append(new_choice)
-        return json.dumps(response_copy, sort_keys=True)
+        response["response"] = response_copy
+        return json.dumps(response, sort_keys=True)
     def key_to_response(self, key: str) -> Dict:
         """
@@ -114,7 +116,7 @@ class NumpyByteSerializer(Serializer):
             unnormalized response dict.
         """
         response = json.loads(key)
-        for choice in response["choices"]:
+        for choice in response["response"]["choices"]:
             hash_str = choice["array"]
             byte_str = bytes.fromhex(hash_str)
             with io.BytesIO(byte_str) as f:
@@ -152,14 +154,15 @@ class ArraySerializer(Serializer):
         Returns:
             normalized key.
         """
+        sub_response = response["response"]
         # Assume response is a dict with keys "choices" -> List dicts
         # with keys "array".
-        choices = response["choices"]
+        choices = sub_response["choices"]
         # We don't want to modify the response in place
         # but we want to avoid calling deepcopy on an array
-        del response["choices"]
-        response_copy = response.copy()
-        response["choices"] = choices
+        del sub_response["choices"]
+        response_copy = sub_response.copy()
+        sub_response["choices"] = choices
         response_copy["choices"] = []
         for choice in choices:
             if "array" not in choice:
@@ -179,7 +182,8 @@ class ArraySerializer(Serializer):
             response_copy["choices"].append(new_choice)
             if not self.writer.contains_key(hash_str):
                 self.writer.put(hash_str, arr)
-        return json.dumps(response_copy, sort_keys=True)
+        response["response"] = response_copy
+        return json.dumps(response, sort_keys=True)
     def key_to_response(self, key: str) -> Dict:
         """
@@ -194,7 +198,7 @@ class ArraySerializer(Serializer):
             unnormalized response dict.
         """
         response = json.loads(key)
-        for choice in response["choices"]:
+        for choice in response["response"]["choices"]:
             hash_str = choice["array"]
             choice["array"] = self.writer.get(hash_str)
         return response
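The serializer edits above all account for one structural change: the cached payload is now the full `Response.to_dict()` output, so choices live under `payload["response"]["choices"]` instead of `payload["choices"]`. A sketch of the nested shape being normalized (field values are illustrative only):

```python
# Approximate shape handed to request_to_key / produced by key_to_response
payload = {
    "response": {
        "choices": [{"array": "deadbeef", "token_logprobs": [0.3]}],
    },
    "usages": {"usages": []},
    "cached": False,
    "response_type": "array",
    "request_type": "DiffusionRequest",
    "item_dtype": "float64",
}
sub_response = payload["response"]  # serializers index one level down...
choices = sub_response["choices"]   # ...before normalizing each "array"
```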

@@ -94,7 +94,7 @@ class AI21Client(Client):
         """
         return {"model_name": self.NAME, "engine": getattr(self, "engine")}
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -10,8 +10,21 @@ import aiohttp
 import requests
 from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
-from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
-from manifest.response import RESPONSE_CONSTRUCTORS, Response
+from manifest.request import (
+    DEFAULT_REQUEST_KEYS,
+    NOT_CACHE_KEYS,
+    LMScoreRequest,
+    Request,
+)
+from manifest.response import (
+    RESPONSE_CONSTRUCTORS,
+    ArrayModelChoice,
+    LMModelChoice,
+    ModelChoices,
+    Response,
+    Usage,
+    Usages,
+)
 logger = logging.getLogger(__name__)
@@ -161,16 +174,30 @@ class Client(ABC):
         request_params = self.get_request_params(request)
         for key in NOT_CACHE_KEYS:
             request_params.pop(key, None)
+        # Make sure to add model params and request class
         request_params.update(self.get_model_params())
+        request_params["request_cls"] = request.__class__.__name__
         return request_params
     def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
         """Split usage into list of usages for each prompt."""
         return []
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def get_model_choices(self, response: Dict) -> ModelChoices:
+        """Format response to ModelChoices."""
+        # Array or text response
+        response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
+        if response_type == "array":
+            choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
+                ArrayModelChoice(**choice) for choice in response["choices"]
+            ]
+        else:
+            choices = [LMModelChoice(**choice) for choice in response["choices"]]
+        return ModelChoices(choices=choices)
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
-        Format response to dict.
+        Validate response as dict.
         Args:
             response: response
@@ -246,7 +273,7 @@ class Client(ABC):
         except requests.exceptions.HTTPError:
             logger.error(res.json())
             raise requests.exceptions.HTTPError(res.json())
-        return self.format_response(res.json(), request_params)
+        return self.validate_response(res.json(), request_params)
     @retry(
         reraise=True,
@@ -277,7 +304,7 @@ class Client(ABC):
         ) as res:
             res.raise_for_status()
             res_json = await res.json(content_type=None)
-            return self.format_response(res_json, request_params)
+            return self.validate_response(res_json, request_params)
     def run_request(self, request: Request) -> Response:
         """
@@ -301,11 +328,16 @@ class Client(ABC):
         for key in DEFAULT_REQUEST_KEYS:
             request_params.pop(key, None)
         response_dict = self._run_completion(request_params, retry_timeout)
+        usages = None
+        if "usage" in response_dict:
+            usages = [Usage(**usage) for usage in response_dict["usage"]]
         return Response(
-            response_dict,
+            response=self.get_model_choices(response_dict),
             cached=False,
-            request_params=request_params,
-            **RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}),  # type: ignore
+            request=request,
+            usages=Usages(usages=usages) if usages else None,
+            **RESPONSE_CONSTRUCTORS[self.REQUEST_CLS],  # type: ignore
         )
     async def arun_batch_request(self, request: Request) -> Response:
@@ -353,18 +385,20 @@ class Client(ABC):
             if "usage" in res_dict:
                 usages.extend(res_dict["usage"])
         final_response_dict = {"choices": choices}
+        final_usages = None
         if usages:
-            final_response_dict["usage"] = usages
+            final_usages = Usages(usages=[Usage(**usage) for usage in usages])
         return Response(
-            final_response_dict,
+            self.get_model_choices(final_response_dict),
             cached=False,
-            request_params=request_params,
-            **RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}),  # type: ignore
+            request=request,
+            usages=final_usages,
+            **RESPONSE_CONSTRUCTORS[self.REQUEST_CLS],  # type: ignore
         )
     def get_score_prompt_request(
         self,
-        request: Request,
+        request: LMScoreRequest,
     ) -> Response:
         """
         Get the logit score of the prompt via a forward pass of the model.
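In the new flow, the client validates the provider payload (`validate_response`) and then lifts it into typed objects (`get_model_choices`, `Usage`/`Usages`) before building the `Response`. A compressed sketch of that lifting for a text-type request (the raw dict is hypothetical; the text-vs-array decision really comes from `RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]`):

```python
from manifest.response import LMModelChoice, ModelChoices, Usage, Usages

raw = {
    "choices": [{"text": "hello", "token_logprobs": None}],
    "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}],
}

# Text-type request: each raw choice becomes an LMModelChoice
model_choices = ModelChoices(
    choices=[LMModelChoice(**choice) for choice in raw["choices"]]
)
usages = Usages(usages=[Usage(**u) for u in raw["usage"]])
```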

@@ -93,7 +93,7 @@ class CohereClient(Client):
         """
         return {"model_name": self.NAME, "engine": getattr(self, "engine")}
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -86,7 +86,7 @@ class DiffuserClient(Client):
         res["client_name"] = self.NAME
         return res
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -3,8 +3,8 @@ import logging
 from typing import Any, Dict, Optional
 from manifest.clients.client import Client
-from manifest.request import LMRequest, Request
-from manifest.response import Response
+from manifest.request import LMRequest, LMScoreRequest, Request
+from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages
 logger = logging.getLogger(__name__)
@@ -86,15 +86,30 @@ class DummyClient(Client):
             num_results = 1
         request_params = request.to_dict(self.PARAMS)
-        response_dict = {
-            "choices": [{"text": "hello"}]
-            * int(request_params["num_results"])
-            * num_results,
-            "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}]
-            * int(request_params["num_results"])
-            * num_results,
-        }
-        return Response(response_dict, False, request_params)
+        return Response(
+            response=ModelChoices(
+                choices=[LMModelChoice(text="hello")]  # type: ignore
+                * int(request_params["num_results"])
+                * num_results
+            ),
+            cached=False,
+            request=request,
+            usages=Usages(
+                usages=[
+                    Usage(
+                        **{
+                            "prompt_tokens": 1,
+                            "completion_tokens": 1,
+                            "total_tokens": 2,
+                        }
+                    )
+                ]
+                * int(request_params["num_results"])
+                * num_results
+            ),
+            response_type="text",
+            request_type=self.REQUEST_CLS,
+        )
     async def arun_batch_request(self, request: Request) -> Response:
         """
@@ -110,7 +125,7 @@ class DummyClient(Client):
     def get_score_prompt_request(
         self,
-        request: Request,
+        request: LMScoreRequest,
     ) -> Response:
         """
         Get the logit score of the prompt via a forward pass of the model.
@@ -126,17 +141,27 @@ class DummyClient(Client):
             num_results = len(request.prompt)
         else:
             num_results = 1
-        request_params = {"prompt": request.prompt}
         response_dict = {
             "choices": [
                 {
                     "text": request.prompt
                     if isinstance(request.prompt, str)
                     else request.prompt[i],
-                    "logprob": 0.3,
+                    "token_logprobs": [0.3],
                 }
                 for i in range(num_results)
             ]
         }
-        return Response(response_dict, False, request_params)
+        return Response(
+            response=ModelChoices(
+                choices=[
+                    LMModelChoice(**choice)  # type: ignore
+                    for choice in response_dict["choices"]
+                ]
+            ),
+            cached=False,
+            request=request,
+            usages=None,
+            response_type="text",
+            request_type=LMScoreRequest,
+        )

@@ -5,8 +5,8 @@ from typing import Any, Dict, Optional
 import requests
 from manifest.clients.client import Client
-from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request
-from manifest.response import Response
+from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, LMScoreRequest
+from manifest.response import LMModelChoice, ModelChoices, Response
 logger = logging.getLogger(__name__)
@@ -82,7 +82,7 @@ class HuggingFaceClient(Client):
     def get_score_prompt_request(
         self,
-        request: Request,
+        request: LMScoreRequest,
     ) -> Response:
         """
         Get the logit score of the prompt via a forward pass of the model.
@@ -116,4 +116,13 @@ class HuggingFaceClient(Client):
             logger.error(res.text)
             raise e
         response_dict = res.json()
-        return Response(response_dict, cached=False, request_params=request_params)
+        return Response(
+            response=ModelChoices(
+                choices=[LMModelChoice(**choice) for choice in response_dict["choices"]]
+            ),
+            cached=False,
+            request=request,
+            usages=None,
+            response_type="text",
+            request_type=LMScoreRequest,
+        )

@@ -72,7 +72,7 @@ class HuggingFaceEmbeddingClient(Client):
         res["client_name"] = self.NAME
         return res
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -37,6 +37,7 @@ class OpenAIClient(Client):
         "n": ("n", 1),
         "top_p": ("top_p", 1.0),
         "top_k": ("best_of", 1),
+        "logprobs": ("logprobs", None),
         "stop_sequences": ("stop", None),  # OpenAI doesn't like empty lists
         "presence_penalty": ("presence_penalty", 0.0),
         "frequency_penalty": ("frequency_penalty", 0.0),

@@ -76,7 +76,7 @@ class OpenAIEmbeddingClient(OpenAIClient):
         """
         return {"model_name": self.NAME, "engine": getattr(self, "engine")}
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -143,7 +143,7 @@ class TOMAClient(Client):
         }
         return heartbeats
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -46,7 +46,7 @@ class TOMADiffuserClient(TOMAClient):
         """
         return {"model_name": self.NAME, "engine": getattr(self, "engine")}
-    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
         """
         Format response to dict.

@@ -2,7 +2,7 @@
 import asyncio
 import copy
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
 import numpy as np
@@ -17,8 +17,8 @@ from manifest.connections.client_pool import (
     ClientConnection,
     ClientConnectionPool,
 )
-from manifest.request import Request
-from manifest.response import Response
+from manifest.request import LMScoreRequest, Request
+from manifest.response import ModelChoices, Response, Usage, Usages
 logging.getLogger("openai").setLevel(logging.WARNING)
 logger = logging.getLogger(__name__)
@@ -178,82 +178,72 @@ class Manifest:
         number_prompts = len(cached_idx_to_response)
         single_output = False
         if response:
-            if isinstance(response.get_request()["prompt"], str):
+            if isinstance(response.get_request_obj().prompt, str):
                 single_output = True
                 number_prompts += 1
             else:
-                number_prompts += len(response.get_request()["prompt"])
-        response_gen_key = None
-        response_logits_key = None
-        response_item_key = None
+                number_prompts += len(response.get_request_obj().prompt)
+        response_type = None
+        request_type: Type[Request] = None
         for idx in range(number_prompts):
             if idx in cached_idx_to_response:
                 cached_res = cached_idx_to_response[idx]
-                response_gen_key = cached_res.generation_key
-                response_logits_key = cached_res.logits_key
-                response_item_key = cached_res.item_key
-                response_usage_key = cached_res.usage_key
-                all_input_prompts.append(cached_res.get_request()["prompt"])
-                json_response = cached_res.get_json_response()
+                response_type = cached_res._response_type
+                request_type = cached_res._request_type
+                all_input_prompts.append(cached_res.get_request_obj().prompt)
                 if request.n == 1:
                     assert (
-                        len(json_response[response_gen_key]) == 1
+                        len(cached_res.get_response_obj().choices) == 1
                     ), "cached response should have only one choice"
-                all_model_choices.extend(json_response[response_gen_key])
-                if response_usage_key:
-                    all_usages.extend(json_response[response_usage_key])
+                all_model_choices.extend(cached_res.get_response_obj().choices)
+                if cached_res.get_usage_obj().usages:
+                    all_usages.extend(cached_res.get_usage_obj().usages)
             else:
                 assert response is not None, "response should not be None"
                 response = cast(Response, response)
-                response_gen_key = response.generation_key
-                response_logits_key = response.logits_key
-                response_item_key = response.item_key
-                response_usage_key = response.usage_key
+                response_type = response._response_type
+                request_type = response._request_type
                 # the choices list in the response is a flat one.
                 # length is request.n * num_prompts
-                current_choices = response.get_json_response()[response_gen_key][
+                current_choices = response.get_response_obj().choices[
                     response_idx * request.n : (response_idx + 1) * request.n
                 ]
                 all_model_choices.extend(current_choices)
-                if isinstance(response.get_request()["prompt"], list):
-                    prompt = response.get_request()["prompt"][response_idx]
+                if isinstance(response.get_request_obj().prompt, list):
+                    prompt = response.get_request_obj().prompt[response_idx]
                 else:
-                    prompt = str(response.get_request()["prompt"])
-                if response_usage_key:
-                    usage = response.get_json_response()[response_usage_key][
+                    prompt = str(response.get_request_obj().prompt)
+                usages: Optional[List[Usage]] = None
+                if response.get_usage_obj().usages:
+                    usages = response.get_usage_obj().usages[
                         response_idx * request.n : (response_idx + 1) * request.n
                     ]
-                    all_usages.extend(usage)
+                    all_usages.extend(usages)
                 all_input_prompts.append(prompt)
                 # set cache
                 new_request = copy.deepcopy(request)
                 new_request.prompt = prompt
                 cache_key = client.get_cache_key(new_request)
-                new_response_key = copy.deepcopy(response.get_json_response())
-                new_response_key[response_gen_key] = current_choices
-                if response_usage_key:
-                    new_response_key[response_usage_key] = usage
-                self.cache.set(cache_key, new_response_key)
+                new_response = copy.deepcopy(response)
+                new_response._response.choices = current_choices
+                new_response._usages = Usages(usages=(usages or []))
+                self.cache.set(cache_key, new_response.to_dict(drop_request=True))
                 response_idx += 1
         new_request = copy.deepcopy(request)
         new_request.prompt = (
-            all_input_prompts
+            all_input_prompts  # type: ignore
            if len(all_input_prompts) > 1 or not single_output
            else all_input_prompts[0]
         )
-        new_response = {response_gen_key: all_model_choices}
-        if response_usage_key:
-            new_response[response_usage_key] = all_usages
         response_obj = Response(
-            new_response,
+            response=ModelChoices(choices=all_model_choices),
             cached=len(cached_idx_to_response) > 0,
-            request_params=client.get_cache_key(new_request),
-            generation_key=response_gen_key,
-            logits_key=response_logits_key,
-            item_key=response_item_key,
-            usage_key=response_usage_key,
+            request=new_request,
+            usages=Usages(usages=all_usages),
+            response_type=response_type,
+            request_type=request_type,
         )
         return response_obj
@@ -457,20 +447,20 @@ class Manifest:
         client = self.client_pool.get_client()
         # Must pass kwargs as dict for client "pop" methods removed used arguments
         request_params = client.get_request(prompt, kwargs)
-        request_params.request_type = "score_prompt"
-        if request_params.n > 1:
+        request_params_as_score = LMScoreRequest(**request_params.to_dict())
+        if request_params_as_score.n > 1:
             raise ValueError("Sequence scoring does not support n > 1.")
-        self._validate_kwargs(kwargs, request_params)
-        cached_idx_to_response, request_params = self._split_cached_requests(
-            request_params, client, overwrite_cache
+        self._validate_kwargs(kwargs, request_params_as_score)
+        cached_idx_to_response, request_params_as_score = self._split_cached_requests(  # type: ignore # noqa: E501
+            request_params_as_score, client, overwrite_cache
         )
         # If not None value or empty list - run new request
-        if request_params.prompt:
+        if request_params_as_score.prompt:
             try:
                 response = cast(HuggingFaceClient, client).get_score_prompt_request(
-                    request_params
+                    request_params_as_score
                 )
             except AttributeError:
                 raise ValueError("`score_prompt` only supported for HF models.")
@@ -479,7 +469,7 @@ class Manifest:
             response = None
         final_response = self._stitch_responses_and_cache(
-            request=request_params,
+            request=request_params_as_score,
             client=client,
             response=response,
             cached_idx_to_response=cached_idx_to_response,

@@ -3,13 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from pydantic import BaseModel
+# Used when unioning requests after async connection pool
+ENGINE_SEP = "::"
 NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
 # The below should match those in Request.
 DEFAULT_REQUEST_KEYS = {
     "client_timeout": ("client_timeout", 60),  # seconds
     "batch_size": ("batch_size", 8),
     "run_id": ("run_id", None),
-    "request_type": ("request_type", None),
 }
@@ -34,9 +35,6 @@ class Request(BaseModel):
     # Batch size for async batch run
     batch_size: int = 8
-    # Request type None is for completion. Used for scoring prompt
-    request_type: str = None
     def to_dict(
         self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
     ) -> Dict[str, Any]:
@@ -78,6 +76,9 @@ class LMRequest(Request):
     # Top k sampling taking top_k highest probability tokens
     top_k: int = 50
+    # Logprobs return value
+    logprobs: Optional[int] = None
     # Stop sequences
     stop_sequences: Optional[List[str]] = None
@@ -100,6 +101,12 @@ class LMRequest(Request):
     frequency_penalty: float = 0
+class LMScoreRequest(LMRequest):
+    """Language Model Score Request object."""
+    pass
 class EmbeddingRequest(Request):
     """Embedding Request object."""
@@ -109,9 +116,6 @@ class EmbeddingRequest(Request):
 class DiffusionRequest(Request):
     """Diffusion Model Request object."""
-    # Request type
-    request_type: str = "diffusion"
     # Number of steps
     num_inference_steps: int = 50
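`LMScoreRequest` replaces the old free-form `request_type: str` field: scoring is now its own `LMRequest` subclass, so behavior can be keyed off the request class (see `RESPONSE_CONSTRUCTORS` in `manifest/response.py`). A short sketch of the upgrade path `Manifest.score_prompt` uses above:

```python
from manifest.request import LMRequest, LMScoreRequest

request_params = LMRequest(prompt="hello", engine="dummy")
# Rebuild the generic request as a score request, as score_prompt now does
score_request = LMScoreRequest(**request_params.to_dict())
assert isinstance(score_request, LMRequest)
```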

@@ -1,21 +1,25 @@
 """Client response."""
 import copy
 import json
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Type, Union, cast
 import numpy as np
+from pydantic import BaseModel
-from manifest.request import DiffusionRequest, EmbeddingRequest
-RESPONSE_CONSTRUCTORS = {
-    EmbeddingRequest: {
-        "logits_key": "token_logprobs",
-        "item_key": "array",
-    },
-    DiffusionRequest: {
-        "logits_key": "token_logprobs",
-        "item_key": "array",
-    },
+from manifest.request import (
+    ENGINE_SEP,
+    DiffusionRequest,
+    EmbeddingRequest,
+    LMRequest,
+    LMScoreRequest,
+    Request,
+)
+RESPONSE_CONSTRUCTORS: Dict[Type[Request], Dict[str, Union[str, Type[Request]]]] = {
+    LMRequest: {"response_type": "text", "request_type": LMRequest},
+    LMScoreRequest: {"response_type": "text", "request_type": LMScoreRequest},
+    EmbeddingRequest: {"response_type": "array", "request_type": EmbeddingRequest},
+    DiffusionRequest: {"response_type": "array", "request_type": DiffusionRequest},
 }
@@ -29,94 +33,114 @@ class NumpyArrayEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
+class Usage(BaseModel):
+    """Prompt usage class."""
+    completion_tokens: int = 0
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+class Usages(BaseModel):
+    """Prompt usage class."""
+    usages: List[Usage]
+class LMModelChoice(BaseModel):
+    """Model single completion."""
+    text: str
+    token_logprobs: Optional[List[float]] = None
+    tokens: Optional[List[int]] = None
+class ArrayModelChoice(BaseModel):
+    """Model single completion."""
+    array: np.ndarray
+    token_logprobs: Optional[List[float]] = None
+    class Config:
+        """Pydantic config class."""
+        arbitrary_types_allowed = True
+class ModelChoices(BaseModel):
+    """Model choices."""
+    choices: List[Union[LMModelChoice, ArrayModelChoice]]
 class Response:
     """Response class."""
     def __init__(
         self,
-        response: Dict,  # TODO: make pydantic model
+        response: ModelChoices,
         cached: bool,
-        request_params: Dict,  # TODO: use request pydantic model
-        generation_key: str = "choices",
-        logits_key: str = "token_logprobs",
-        item_key: str = "text",
-        usage_key: str = "usage",
+        request: Request,
+        response_type: str,
+        request_type: Type[Request],
+        usages: Optional[Usages] = None,
     ):
         """
         Initialize response.
         Args:
             response: response dict.
+            usages: usage dict.
             cached: whether response is cached.
-            request_params: request parameters.
-            generation_key: key for generation results.
-            logits_key: key for logits.
-            item_key: key for item in the generations.
+            request: request.
+            response_type: response type.
+            request_type: request type.
         """
-        self.generation_key = generation_key
-        self.logits_key = logits_key
-        self.item_key = item_key
-        self.usage_key = usage_key
-        self.item_dtype = None
-        if isinstance(response, dict):
-            self._response = response
-        else:
-            raise ValueError(f"Response must be dict. Response is\n{response}.")
-        if (
-            (self.generation_key not in self._response)
-            or (not isinstance(self._response[self.generation_key], list))
-            or (len(self._response[self.generation_key]) <= 0)
-        ):
-            raise ValueError(
-                "Response must be serialized to a dict with a nonempty"
-                f" list of choices. Response is\n{self._response}."
-            )
-        # Turn off usage if it is not in response
-        if self.usage_key not in self._response:
-            self.usage_key = None
-        else:
-            if not isinstance(self._response[self.usage_key], list):
-                raise ValueError(
-                    "Response must be a list with usage dicts, one per choice."
-                    f" Response is\n{self._response}."
-                )
-        if self.item_key not in self._response[self.generation_key][0]:
-            raise ValueError(
-                "Response must be serialized to a dict with a "
-                f"list of choices with {self.item_key} field"
-            )
-        if (
-            self.logits_key in self._response[self.generation_key][0]
-            and self._response[self.generation_key][0][self.logits_key]
-        ):
-            if not isinstance(
-                self._response[self.generation_key][0][self.logits_key], list
-            ):
-                raise ValueError(
-                    f"{self.logits_key} must be a list of items "
-                    "one for each token in the choice."
-                )
-        if isinstance(
-            self._response[self.generation_key][0][self.item_key], np.ndarray
-        ):
-            self.item_dtype = str(
-                self._response[self.generation_key][0][self.item_key].dtype
-            )
+        self._item_dtype = None
+        self._response_type = response_type
+        if self._response_type not in {"array", "text"}:
+            raise ValueError(f"Invalid response type {self._response_type}")
+        self._request_type = request_type
+        self._response = response
+        self._usages = usages or Usages(usages=[])
         self._cached = cached
-        self._request_params = request_params
+        self._request = request
+        if self._response.choices:
+            if response_type == "array":
+                if not isinstance(self._response.choices[0], ArrayModelChoice):
+                    raise ValueError(
+                        "response_type is array but response is "
+                        f"{self._response.choices[0].__class__}"
+                    )
+                self._item_dtype = str(
+                    cast(ArrayModelChoice, self._response.choices[0]).array.dtype
+                )
+            else:
+                if not isinstance(self._response.choices[0], LMModelChoice):
+                    raise ValueError(
+                        "response_type is text but response is "
+                        f"{self._response.choices[0].__class__}"
+                    )
     def is_cached(self) -> bool:
         """Check if response is cached."""
         return self._cached
-    def get_request(self) -> Dict:
+    def get_request_obj(self) -> Request:
         """Get request parameters."""
-        return self._request_params
+        return self._request
+    def get_response_obj(self) -> ModelChoices:
+        """Get response object."""
+        return self._response
+    def get_usage_obj(self) -> Usages:
+        """Get usage object."""
+        return self._usages
     def get_json_response(self) -> Dict:
         """Get response dict without parsing."""
-        return self._response
+        return self._response.dict()
     def get_response(
         self, stop_token: str = "", is_batch: bool = False
@@ -132,7 +156,8 @@ class Response:
             lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
         )
         extracted_items = [
-            choice[self.item_key] for choice in self._response[self.generation_key]
+            choice.text if isinstance(choice, LMModelChoice) else choice.array
+            for choice in self._response.choices
         ]
         if len(extracted_items) == 0:
             return None
@@ -153,25 +178,15 @@ class Response:
         if len(responses) == 1:
             return responses[0]
         first_response = responses[0]
-        generation_key = first_response.generation_key
-        logits_key = first_response.logits_key
-        item_key = first_response.item_key
-        # Usage key may be None, so get first not-None value
-        possible_usage_keys = [r.usage_key for r in responses if r.usage_key]
-        if possible_usage_keys:
-            usage_key = possible_usage_keys[0]
-        else:
-            usage_key = None
-        request = first_response._request_params
+        request_type = first_response._request_type
+        response_type = first_response._response_type
+        request = first_response.get_request_obj()
         # Make sure all responses have the same keys
         if not all(
             [
-                (r.generation_key == generation_key)
-                and (r.logits_key == logits_key)
-                and (r.item_key == item_key)
-                # Usage key can be empty
-                and (not r.usage_key or not usage_key or r.usage_key == usage_key)
+                (r._request_type == request_type)
+                and (r._response_type == response_type)
                 for r in responses
             ]
         ):
@@ -181,33 +196,31 @@ class Response:
         all_prompts = []
         all_choices = []
         all_usages = []
+        all_engines = []
         for res in responses:
-            json_response = res.get_json_response()
-            res_prompt = res.get_request()["prompt"]
+            all_engines.extend(res.get_request_obj().engine.split(ENGINE_SEP))
+            res_prompt = res.get_request_obj().prompt
             if isinstance(res_prompt, str):
                 res_prompt = [res_prompt]
             all_prompts.extend(res_prompt)
-            all_choices.extend(json_response[generation_key])
-            if usage_key and usage_key in json_response:
-                all_usages.extend(json_response[usage_key])
+            all_choices.extend(res.get_response_obj().choices)
+            if res.get_usage_obj().usages:
+                all_usages.extend(res.get_usage_obj().usages)
             else:
-                # Add empty usage
-                all_usages.extend([{}] * len(res_prompt))
+                # Add empty usages if not present
+                all_usages.extend([Usage()] * len(res_prompt))
         new_request = copy.deepcopy(request)
-        # TODO: add both models back in request. This should be a lot
-        # easier after I pydantic the response and request more formally
-        new_request["prompt"] = all_prompts
-        new_response = {generation_key: all_choices}
-        if usage_key:
-            new_response[usage_key] = all_usages
+        new_request.engine = ENGINE_SEP.join(sorted(set(all_engines)))
+        new_request.prompt = all_prompts
+        new_response = ModelChoices(choices=all_choices)
+        new_usages = Usages(usages=all_usages)
         response_obj = cls(
-            new_response,
+            response=new_response,
             cached=any(res.is_cached() for res in responses),
-            request_params=new_request,
-            generation_key=generation_key,
-            logits_key=logits_key,
-            item_key=item_key,
-            usage_key=usage_key,
+            request=new_request,
+            usages=new_usages,
+            request_type=request_type,
+            response_type=response_type,
         )
         return response_obj
@@ -232,56 +245,74 @@ class Response:
             serialized response.
         """
         deserialized = json.loads(value)
-        item_dtype = deserialized["item_dtype"]
-        if item_dtype:
-            for choice in deserialized["response"][deserialized["generation_key"]]:
-                choice[deserialized["item_key"]] = np.array(
-                    choice[deserialized["item_key"]]
-                ).astype(item_dtype)
-        return cls(
-            deserialized["response"],
-            deserialized["cached"],
-            deserialized["request_params"],
-            generation_key=deserialized["generation_key"],
-            logits_key=deserialized["logits_key"],
-            item_key=deserialized["item_key"],
-        )
+        return cls.from_dict(deserialized)
-    def to_dict(self) -> Dict:
+    def to_dict(self, drop_request: bool = False) -> Dict:
         """
         Get dictionary representation of response.
         Returns:
             dictionary representation of response.
         """
-        return {
-            "generation_key": self.generation_key,
-            "logits_key": self.logits_key,
-            "item_key": self.item_key,
-            "item_dtype": self.item_dtype,
-            "response": self._response,
+        to_return = {
+            "response": self._response.dict(),
+            "usages": self._usages.dict(),
             "cached": self._cached,
-            "request_params": self._request_params,
+            "request": self._request.dict(),
+            "response_type": self._response_type,
+            "request_type": str(self._request_type.__name__),
+            "item_dtype": self._item_dtype,
         }
+        if drop_request:
+            to_return.pop("request")
+        return to_return
     @classmethod
-    def from_dict(cls, response: Dict) -> "Response":
+    def from_dict(
+        cls, response_dict: Dict, request_dict: Optional[Dict] = None
+    ) -> "Response":
         """
         Create response from dictionary.
         Args:
             response: dictionary representation of response.
+            request_dict: dictionary representation of request which
+                will override what is in response_dict.
         Returns:
             response.
         """
+        if "request" not in response_dict and request_dict is None:
+            raise ValueError(
+                "Request dictionary must be provided if "
+                "request is not in response dictionary."
+            )
+        item_dtype = response_dict["item_dtype"]
+        response_type = response_dict["response_type"]
+        if response_dict["request_type"] == "LMRequest":
+            request_type: Type[Request] = LMRequest
+        elif response_dict["request_type"] == "LMScoreRequest":
+            request_type = LMScoreRequest
+        elif response_dict["request_type"] == "EmbeddingRequest":
+            request_type = EmbeddingRequest
+        elif response_dict["request_type"] == "DiffusionRequest":
+            request_type = DiffusionRequest
+        choices: List[Union[LMModelChoice, ArrayModelChoice]] = []
+        if item_dtype and response_type == "array":
+            for choice in response_dict["response"]["choices"]:
+                choice["array"] = np.array(choice["array"]).astype(item_dtype)
+                choices.append(ArrayModelChoice(**choice))
+        else:
+            for choice in response_dict["response"]["choices"]:
+                choices.append(LMModelChoice(**choice))
+        response = ModelChoices(choices=choices)
         return cls(
-            response["response"],
-            response["cached"],
-            response["request_params"],
-            generation_key=response["generation_key"],
-            logits_key=response["logits_key"],
-            item_key=response["item_key"],
+            response=response,
+            usages=Usages(**response_dict["usages"]),
+            cached=response_dict["cached"],
+            request=request_type(**(request_dict or response_dict["request"])),
+            response_type=response_type,
+            request_type=request_type,
        )
     def __str__(self) -> str:
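The new `to_dict`/`from_dict` pair keeps serialization symmetric: `request_type` is written as the class name string and mapped back to the class on load, and arrays are restored via the stored `item_dtype`. A minimal round-trip sketch for a text response:

```python
import json
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Response

resp = Response(
    response=ModelChoices(choices=[LMModelChoice(text="hi", token_logprobs=[0.1])]),
    cached=False,
    request=LMRequest(prompt="hi", engine="dummy"),
    usages=None,
    response_type="text",
    request_type=LMRequest,
)

wire = json.dumps(resp.to_dict(), sort_keys=True)  # as serialize() would emit
back = Response.from_dict(json.loads(wire))        # as deserialize() now does
assert back.get_response_obj().choices[0].text == "hi"
```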

@@ -4,9 +4,94 @@ import shutil
 from pathlib import Path
 from typing import Generator
+import numpy as np
 import pytest
 import redis
+from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest
+from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices
+@pytest.fixture
+def model_choice() -> ModelChoices:
+    """Get dummy model choice."""
+    model_choices = ModelChoices(
+        choices=[
+            LMModelChoice(text="hello", token_logprobs=[0.1, 0.2]),
+            LMModelChoice(text="bye", token_logprobs=[0.3]),
+        ]
+    )
+    return model_choices
+@pytest.fixture
+def model_choice_single() -> ModelChoices:
+    """Get dummy model choice."""
+    model_choices = ModelChoices(
+        choices=[
+            LMModelChoice(text="helloo", token_logprobs=[0.1, 0.2]),
+        ]
+    )
+    return model_choices
+@pytest.fixture
+def model_choice_arr() -> ModelChoices:
+    """Get dummy model choice."""
+    np.random.seed(0)
+    model_choices = ModelChoices(
+        choices=[
+            ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.1, 0.2]),
+            ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.3]),
+        ]
+    )
+    return model_choices
+@pytest.fixture
+def model_choice_arr_int() -> ModelChoices:
+    """Get dummy model choice."""
+    np.random.seed(0)
+    model_choices = ModelChoices(
+        choices=[
+            ArrayModelChoice(
+                array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.1, 0.2]
+            ),
+            ArrayModelChoice(
+                array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.3]
+            ),
+        ]
+    )
+    return model_choices
+@pytest.fixture
+def request_lm() -> LMRequest:
+    """Get dummy request."""
+    request = LMRequest(prompt=["what", "cat"])
+    return request
+@pytest.fixture
+def request_lm_single() -> LMRequest:
+    """Get dummy request."""
+    request = LMRequest(prompt="monkey", engine="dummy")
+    return request
+@pytest.fixture
+def request_array() -> EmbeddingRequest:
+    """Get dummy request."""
+    request = EmbeddingRequest(prompt="hello")
+    return request
+@pytest.fixture
+def request_diff() -> DiffusionRequest:
+    """Get dummy request."""
+    request = DiffusionRequest(prompt="hello")
+    return request
 @pytest.fixture
 def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:

@@ -12,6 +12,7 @@ from manifest.caches.postgres import PostgresCache
 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
 from manifest.request import DiffusionRequest, LMRequest, Request
+from manifest.response import ArrayModelChoice, ModelChoices, Response
 def _get_postgres_cache(
@@ -78,7 +79,16 @@ def test_key_get_and_set(
 @pytest.mark.usefixtures("postgres_cache")
 @pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
 def test_get(
-    sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
+    sqlite_cache: str,
+    redis_cache: str,
+    postgres_cache: str,
+    cache_type: str,
+    model_choice: ModelChoices,
+    model_choice_single: ModelChoices,
+    model_choice_arr_int: ModelChoices,
+    request_lm: LMRequest,
+    request_lm_single: LMRequest,
+    request_diff: DiffusionRequest,
 ) -> None:
     """Test cache save prompt."""
     if cache_type == "sqlite":
@@ -88,22 +98,51 @@ def test_get(
     elif cache_type == "postgres":
         cache = cast(Cache, _get_postgres_cache())
-    test_request = {"test": "hello", "testA": "world"}
-    test_response = {"choices": [{"text": "hello"}]}
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, test_response)
-    response = cache.get(test_request)
-    assert response.get_response() == "hello"
-    assert response.is_cached()
-    assert response.get_request() == test_request
+    response = Response(
+        response=model_choice_single,
+        cached=False,
+        request=request_lm_single,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    cache_response = cache.get(request_lm_single.dict())
+    assert cache_response is None
+    cache.set(request_lm_single.dict(), response.to_dict(drop_request=True))
+    cache_response = cache.get(request_lm_single.dict())
+    assert cache_response.get_response() == "helloo"
+    assert cache_response.is_cached()
+    assert cache_response.get_request_obj() == request_lm_single
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    cache_response = cache.get(request_lm.dict())
+    assert cache_response is None
+    cache.set(request_lm.dict(), response.to_dict(drop_request=True))
+    cache_response = cache.get(request_lm.dict())
+    assert cache_response.get_response() == ["hello", "bye"]
+    assert cache_response.is_cached()
+    assert cache_response.get_request_obj() == request_lm
     # Test array
-    arr = np.random.rand(4, 4)
-    test_request = {"test": "hello", "testA": "world of images"}
-    compute_arr_response = {"choices": [{"array": arr}]}
+    response = Response(
+        response=model_choice_arr_int,
+        cached=False,
+        request=request_diff,
+        usages=None,
+        request_type=DiffusionRequest,
+        response_type="array",
+    )
     if cache_type == "sqlite":
         cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
@@ -112,103 +151,34 @@ def test_get(
     elif cache_type == "postgres":
         cache = _get_postgres_cache(request_type=DiffusionRequest)
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, compute_arr_response)
-    response = cache.get(test_request)
-    assert np.allclose(response.get_response(), arr)
-    assert response.is_cached()
-    assert response.get_request() == test_request
+    cache_response = cache.get(request_diff.dict())
+    assert cache_response is None
+    cache.set(request_diff.dict(), response.to_dict(drop_request=True))
+    cached_response = cache.get(request_diff.dict())
+    assert np.allclose(
+        cached_response.get_response()[0],
+        cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
+    )
+    assert np.allclose(
+        cached_response.get_response()[1],
+        cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array,
+    )
+    assert cached_response.is_cached()
+    assert cached_response.get_request_obj() == request_diff
     # Test array byte string
-    arr = np.random.rand(4, 4)
-    test_request = {"test": "hello", "testA": "world of images 2"}
-    compute_arr_response = {"choices": [{"array": arr}]}
-    if cache_type == "sqlite":
-        cache = SQLiteCache(
-            sqlite_cache,
-            request_type=DiffusionRequest,
-            cache_args={"array_serializer": "byte_string"},
-        )
-    elif cache_type == "redis":
-        cache = RedisCache(
-            redis_cache,
-            request_type=DiffusionRequest,
-            cache_args={"array_serializer": "byte_string"},
-        )
-    elif cache_type == "postgres":
-        cache = _get_postgres_cache(
-            request_type=DiffusionRequest,
-            cache_args={"array_serializer": "byte_string"},
-        )
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, compute_arr_response)
-    response = cache.get(test_request)
-    assert np.allclose(response.get_response(), arr)
-    assert response.is_cached()
-    assert response.get_request() == test_request
-@pytest.mark.usefixtures("sqlite_cache")
-@pytest.mark.usefixtures("redis_cache")
-@pytest.mark.usefixtures("postgres_cache")
-@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
-def test_get_batch_prompt(
-    sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
-) -> None:
-    """Test cache save prompt."""
-    if cache_type == "sqlite":
-        cache = cast(Cache, SQLiteCache(sqlite_cache))
-    elif cache_type == "redis":
-        cache = cast(Cache, RedisCache(redis_cache))
-    elif cache_type == "postgres":
-        cache = cast(Cache, _get_postgres_cache())
-    test_request = {"test": ["hello", "goodbye"], "testA": "world"}
-    test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]}
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, test_response)
-    response = cache.get(test_request)
-    assert response.get_response() == ["hello", "goodbye"]
-    assert response.is_cached()
-    assert response.get_request() == test_request
-    # Test arrays
-    arr = np.random.rand(4, 4)
-    arr2 = np.random.rand(4, 4)
-    test_request = {"test": ["hello", "goodbye"], "testA": "world of images"}
-    compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
-    if cache_type == "sqlite":
-        cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
-    elif cache_type == "redis":
-        cache = RedisCache(redis_cache, request_type=DiffusionRequest)
-    elif cache_type == "postgres":
-        cache = _get_postgres_cache(request_type=DiffusionRequest)
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, compute_arr_response)
-    response = cache.get(test_request)
-    assert np.allclose(response.get_response()[0], arr)
-    assert np.allclose(response.get_response()[1], arr2)
-    assert response.is_cached()
-    assert response.get_request() == test_request
-    # Test arrays byte serializer
-    arr = np.random.rand(4, 4)
-    arr2 = np.random.rand(4, 4)
-    test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"}
-    compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
+    # Make sure to not hit the cache
+    new_request_diff = DiffusionRequest(**request_diff.dict())
+    new_request_diff.prompt = ["blahhh", "yayayay"]
+    response = Response(
+        response=model_choice_arr_int,
+        cached=False,
+        request=new_request_diff,
+        usages=None,
+        request_type=DiffusionRequest,
+        response_type="array",
+    )
     if cache_type == "sqlite":
         cache = SQLiteCache(
@@ -228,15 +198,21 @@ def test_get(
             cache_args={"array_serializer": "byte_string"},
         )
-    response = cache.get(test_request)
-    assert response is None
-    cache.set(test_request, compute_arr_response)
-    response = cache.get(test_request)
-    assert np.allclose(response.get_response()[0], arr)
-    assert np.allclose(response.get_response()[1], arr2)
-    assert response.is_cached()
-    assert response.get_request() == test_request
+    cached_response = cache.get(new_request_diff.dict())
+    assert cached_response is None
+    cache.set(new_request_diff.dict(), response.to_dict(drop_request=True))
+    cached_response = cache.get(new_request_diff.dict())
+    assert np.allclose(
+        cached_response.get_response()[0],
+        cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
+    )
+    assert np.allclose(
+        cached_response.get_response()[1],
+        cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array,
+    )
+    assert cached_response.is_cached()
+    assert cached_response.get_request_obj() == new_request_diff
 def test_noop_cache() -> None:

@@ -33,10 +33,13 @@ def test_get_request() -> None:
         "prompt": "hello",
         "num_results": 3,
         "engine": "dummy",
+        "request_cls": "LMRequest",
     }
     assert response.get_json_response() == {
-        "choices": [{"text": "hello"}] * 3,
-        "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
+        "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3,
+    }
+    assert response.get_usage_obj().dict() == {
+        "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
     }
     request_params = client.get_request("hello", {"n": 5})
@@ -45,10 +48,13 @@ def test_get_request() -> None:
         "prompt": "hello",
         "num_results": 5,
         "engine": "dummy",
+        "request_cls": "LMRequest",
     }
     assert response.get_json_response() == {
-        "choices": [{"text": "hello"}] * 5,
-        "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
+        "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
+    }
+    assert response.get_usage_obj().dict() == {
+        "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
     }
     request_params = client.get_request(["hello"] * 5, {"n": 1})
@@ -57,8 +63,11 @@ def test_get_request() -> None:
         "prompt": ["hello"] * 5,
         "num_results": 1,
         "engine": "dummy",
+        "request_cls": "LMRequest",
     }
     assert response.get_json_response() == {
-        "choices": [{"text": "hello"}] * 5,
-        "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
+        "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
+    }
+    assert response.get_usage_obj().dict() == {
+        "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
     }

@@ -90,8 +90,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -101,6 +101,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "This is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -116,8 +117,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -127,6 +128,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "This is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
             "run_id": "34",
         }
@@ -143,8 +145,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -154,6 +156,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "Hello is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -169,8 +172,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(stop_token="ll")
     else:
@@ -180,6 +183,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "Hello is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -212,8 +216,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
     else:
@@ -224,6 +228,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "This is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -235,8 +240,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
     else:
@@ -247,6 +252,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "Hello is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -262,6 +268,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
         {
             "prompt": "New prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": n,
         },
     )
@@ -272,8 +279,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
     # Cached because one item is in cache
@@ -287,8 +294,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
        )
         res = result.get_response(stop_token="ll", is_batch=True)
     else:
@@ -309,9 +316,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert res == ["hello"]
     assert (
@@ -319,6 +324,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         {
             "prompt": "This is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": 1,
         },
     )
@@ -330,9 +336,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert res == ["hello", "hello"]
     assert (
@@ -340,6 +344,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         {
             "prompt": "Hello is a prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": 1,
         },
     )
@@ -350,9 +355,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert result.is_cached()
@@ -361,6 +364,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         {
             "prompt": "New prompt",
             "engine": "dummy",
+            "request_cls": "LMRequest",
             "num_results": 1,
         },
     )
@@ -371,9 +375,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     # Cached because one item is in cache
     assert result.is_cached()
@@ -384,9 +386,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(stop_token="ll", is_batch=True)
     assert res == ["he", "he"]
@@ -407,25 +407,43 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
     )
     assert result == {
-        "generation_key": "choices",
-        "logits_key": "token_logprobs",
-        "item_key": "text",
-        "item_dtype": None,
-        "response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
+        "response": {
+            "choices": [
+                {"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None}
+            ]
+        },
+        "usages": {"usages": []},
         "cached": False,
-        "request_params": {
+        "request": {
             "prompt": "This is a prompt",
-            "engine": "dummy",
-            "num_results": 1,
-            "request_type": "score_prompt",
+            "engine": "text-ada-001",
+            "n": 1,
+            "client_timeout": 60,
+            "run_id": None,
+            "batch_size": 8,
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "top_p": 1.0,
+            "top_k": 50,
+            "logprobs": None,
+            "stop_sequences": None,
+            "num_beams": 1,
+            "do_sample": False,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
         },
+        "response_type": "text",
+        "request_type": "LMScoreRequest",
+        "item_dtype": None,
     }

     prompt_list = ["Hello is a prompt", "Hello is another prompt"]
@@ -435,8 +453,8 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
@@ -446,30 +464,48 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "Hello is another prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
     )
     assert result == {
-        "generation_key": "choices",
-        "logits_key": "token_logprobs",
-        "item_key": "text",
-        "item_dtype": None,
         "response": {
             "choices": [
-                {"text": "Hello is a prompt", "logprob": 0.3},
-                {"text": "Hello is another prompt", "logprob": 0.3},
+                {"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None},
+                {
+                    "text": "Hello is another prompt",
+                    "token_logprobs": [0.3],
+                    "tokens": None,
+                },
             ]
         },
+        "usages": {"usages": []},
         "cached": False,
-        "request_params": {
+        "request": {
             "prompt": ["Hello is a prompt", "Hello is another prompt"],
-            "engine": "dummy",
-            "num_results": 1,
-            "request_type": "score_prompt",
+            "engine": "text-ada-001",
+            "n": 1,
+            "client_timeout": 60,
+            "run_id": None,
+            "batch_size": 8,
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "top_p": 1.0,
+            "top_k": 50,
+            "logprobs": None,
+            "stop_sequences": None,
+            "num_beams": 1,
+            "do_sample": False,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
        },
+        "response_type": "text",
+        "request_type": "LMScoreRequest",
+        "item_dtype": None,
     }
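
The expected dicts above are the new `Response.to_dict()` shape: top-level `response`, `usages`, `cached`, and the fully typed `request`, plus `response_type`, `request_type`, and `item_dtype`. A sketch of rebuilding a `Response` from such a dict, assuming `result` holds the first dict asserted above:

```python
from manifest import Response

# Assumed: `result` is the serialized response dict asserted above.
response = Response.from_dict(result)
assert response.is_cached() is False
assert response.get_request_obj().prompt == "This is a prompt"
# With the dummy client, the scored choice text echoes the prompt.
assert response.get_response() == "This is a prompt"
```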
@@ -644,8 +680,8 @@ def test_openai(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
     assert response.get_response() == res
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 15
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 15

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -662,12 +698,9 @@ def test_openai(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 15
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 16
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 15
+    assert response.get_usage_obj().usages[1].total_tokens == 16

     response = cast(
         Response, client.run("Why are there bananas?", return_response=True)
@@ -691,12 +724,9 @@ def test_openai(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 17
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 15
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 17
+    assert response.get_usage_obj().usages[1].total_tokens == 15

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)
@@ -721,8 +751,8 @@ def test_openaichat(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
     assert response.get_response() == res
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 23
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 23

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -749,12 +779,9 @@ def test_openaichat(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 25
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 23
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 25
+    assert response.get_usage_obj().usages[1].total_tokens == 23

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)
@@ -795,8 +822,8 @@ def test_openaiembedding(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), np.ndarray)
     assert np.allclose(response.get_response(), res)
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 5
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 5

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -817,12 +844,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 5
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 6
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 5
+    assert response.get_usage_obj().usages[1].total_tokens == 6

     response = cast(
         Response, client.run("Why are there bananas?", return_response=True)
@@ -857,12 +881,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
         and len(res_list) == 2
         and isinstance(res_list[0], np.ndarray)
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 7
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 5
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 7
+    assert response.get_usage_obj().usages[1].total_tokens == 5

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)

@@ -1,230 +1,301 @@
 """Response test."""
-from typing import Any, Dict
+from typing import List, cast

 import numpy as np
 import pytest

 from manifest import Response
-from manifest.request import LMRequest
+from manifest.request import EmbeddingRequest, LMRequest
+from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages


-def test_init() -> None:
+def test_init(
+    model_choice: ModelChoices,
+    model_choice_arr: ModelChoices,
+    model_choice_arr_int: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
     """Test response initialization."""
-    with pytest.raises(ValueError) as exc_info:
-        response = Response(4, False, {})  # type: ignore
-    assert str(exc_info.value) == "Response must be dict. Response is\n4."
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"test": "hello"}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict with a nonempty list of choices. "
-        "Response is\n{'test': 'hello'}."
-    )
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"choices": [{"blah": "hello"}]}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict "
-        "with a list of choices with text field"
-    )
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"choices": []}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict with a nonempty list of choices. "
-        "Response is\n{'choices': []}."
-    )
-    response = Response({"choices": [{"text": "hello"}]}, False, {})
-    assert response._response == {"choices": [{"text": "hello"}]}
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response._response == model_choice
     assert response._cached is False
-    assert response._request_params == {}
-    assert response.item_dtype is None
-
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response._response == {"choices": [{"text": "hello"}]}
-    assert response._cached is True
-    assert response._request_params == {"request": "yoyo"}
-    assert response.item_dtype is None
+    assert response._request == request_lm
+    assert response._usages == Usages(usages=[])
+    assert response._request_type == LMRequest
+    assert response._response_type == "text"
+    assert response._item_dtype is None

     response = Response(
-        {"generations": [{"txt": "hello"}], "logits": []},
-        False,
-        {},
-        generation_key="generations",
-        logits_key="logits",
-        item_key="txt",
+        response=model_choice_arr_int,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
-    assert response._response == {"generations": [{"txt": "hello"}], "logits": []}
     assert response._cached is False
-    assert response._request_params == {}
-    assert response.item_dtype is None
+    assert response._request == request_array
+    assert sum([usg.total_tokens for usg in response._usages.usages]) == 10
+    assert response._request_type == EmbeddingRequest
+    assert response._response_type == "array"
+    assert response._item_dtype == "int64"

-    int_arr = np.random.randint(20, size=(4, 4))
-    response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
-    )
-    assert response._response == {"choices": [{"array": int_arr}]}
-    assert response._cached is True
-    assert response._request_params == {"request": "yoyo"}
-    assert response.item_dtype == "int64"
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice,
+            cached=False,
+            request=request_lm,
+            usages=None,
+            request_type=LMRequest,
+            response_type="blah",
+        )
+    assert "blah" in str(excinfo.value)
+
+    # Can't convert array with text
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice,
+            cached=False,
+            request=request_lm,
+            usages=None,
+            request_type=LMRequest,
+            response_type="array",
+        )
+    assert str(excinfo.value) == (
+        "response_type is array but response is "
+        "<class 'manifest.response.LMModelChoice'>"
+    )
+
+    # Can't convert text with array
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice_arr,
+            cached=False,
+            request=request_array,
+            usages=None,
+            request_type=LMRequest,
+            response_type="text",
+        )
+    assert str(excinfo.value) == (
+        "response_type is text but response is "
+        "<class 'manifest.response.ArrayModelChoice'>"
+    )


-def test_getters() -> None:
+def test_getters(model_choice: ModelChoices, request_lm: LMRequest) -> None:
     """Test response cached."""
-    response = Response({"choices": [{"text": "hello"}]}, False, {})
-    assert response.get_json_response() == {"choices": [{"text": "hello"}]}
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response.get_response_obj() == model_choice
     assert response.is_cached() is False
-    assert response.get_request() == {}
-
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response.get_json_response() == {"choices": [{"text": "hello"}]}
-    assert response.is_cached() is True
-    assert response.get_request() == {"request": "yoyo"}
-
-    int_arr = np.random.randint(20, size=(4, 4))
-    response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
-    )
-    assert response.get_json_response() == {"choices": [{"array": int_arr}]}
-    assert response.is_cached() is True
-    assert response.get_request() == {"request": "yoyo"}
+    assert response.get_request_obj() == request_lm
+    assert response.get_usage_obj() == Usages(usages=[])
+    assert response.get_json_response() == model_choice.dict()
+    assert response.get_response() == ["hello", "bye"]


-def test_serialize() -> None:
+def test_serialize(
+    model_choice: ModelChoices,
+    model_choice_arr: ModelChoices,
+    model_choice_arr_int: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
     """Test response serialization."""
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    deserialized_response = Response.deserialize(response.serialize())
-    assert deserialized_response._response == {"choices": [{"text": "hello"}]}
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj() == request_lm
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]
+
+    deserialized_response = Response.from_dict(response.to_dict())
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj() == request_lm
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]
+
+    deserialized_response = Response.from_dict(
+        response.to_dict(drop_request=True), request_dict={"prompt": "blahhhh"}
+    )
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj().prompt == "blahhhh"
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]

-    int_arr = np.random.randint(20, size=(4, 4))
+    # Int type
     response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
+        response=model_choice_arr_int,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
     deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response._item_dtype == "int64"
+    assert (
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array.dtype
+        == np.int64
+    )
     assert np.array_equal(
-        deserialized_response._response["choices"][0]["array"], int_arr
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array,
+        cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
     )
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}

-    float_arr = np.random.randn(4, 4)
+    # Float type
     response = Response(
-        {"choices": [{"array": float_arr}]}, True, {"request": "yoyo"}, item_key="array"
+        response=model_choice_arr,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
     deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response._item_dtype == "float64"
+    assert (
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array.dtype
+        == np.float64
+    )
     assert np.array_equal(
-        deserialized_response._response["choices"][0]["array"], float_arr
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array,
+        cast(ArrayModelChoice, model_choice_arr.choices[0]).array,
     )
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}


-def test_get_results() -> None:
+def test_get_results(
+    model_choice: ModelChoices,
+    model_choice_single: ModelChoices,
+    model_choice_arr: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
     """Test response get results."""
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response.get_response() == "hello"
+    response = Response(
+        response=model_choice_single,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response.get_response() == "helloo"
     assert response.get_response(stop_token="ll") == "he"
     assert response.get_response(stop_token="ll", is_batch=True) == ["he"]

     response = Response(
-        {"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
-        True,
-        {"request": "yoyo"},
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
     )
-    assert response.get_response() == ["hello", "my", "name"]
-    assert response.get_response(stop_token="m") == ["hello", "", "na"]
-    assert response.get_response(stop_token="m", is_batch=True) == ["hello", "", "na"]
+    assert response.get_response() == ["hello", "bye"]
+    assert response.get_response(stop_token="b") == ["hello", ""]
+    assert response.get_response(stop_token="y", is_batch=True) == ["hello", "b"]

-    float_arr = np.random.randn(4, 4)
+    float_arr1 = cast(ArrayModelChoice, model_choice_arr.choices[0]).array
+    float_arr2 = cast(ArrayModelChoice, model_choice_arr.choices[1]).array
     response = Response(
-        {"choices": [{"array": float_arr}, {"array": float_arr}]},
-        True,
-        {"request": "yoyo"},
-        item_key="array",
+        response=model_choice_arr,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
-    assert response.get_response() == [float_arr, float_arr]
-    assert response.get_response(stop_token="m") == [float_arr, float_arr]
+    assert np.array_equal(response.get_response()[0], float_arr1)
+    assert np.array_equal(response.get_response()[1], float_arr2)
+    assert np.array_equal(response.get_response(stop_token="t")[0], float_arr1)
+    assert np.array_equal(response.get_response(stop_token="t")[1], float_arr2)


-def test_union_all() -> None:
+def test_union_all(
+    model_choice: ModelChoices,
+    model_choice_single: ModelChoices,
+    request_lm: LMRequest,
+    request_lm_single: LMRequest,
+) -> None:
     """Test union all."""
-    request_paramsa = LMRequest(prompt=["apple", "orange", "pear"]).to_dict()
-    request_paramsa["model"] = "modelA"
-    response_paramsa = {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-        ]
-    }
-    responsea = Response(response_paramsa, False, request_paramsa)
-
-    request_paramsb = LMRequest(prompt=["banana", "pineapple", "mango"]).to_dict()
-    request_paramsb["model"] = "modelB"
-    response_paramsb = {
-        "choices": [
-            {"text": "bye", "token_logprobs": [2]},
-            {"text": "bye 2", "token_logprobs": [2]},
-            {"text": "bye 3", "token_logprobs": [2]},
-        ]
-    }
-    responseb = Response(response_paramsb, False, request_paramsb)
+    response1 = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    response2 = Response(
+        response=model_choice_single,
+        cached=False,
+        request=request_lm_single,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )

-    final_response = Response.union_all([responsea, responseb])
+    final_response = Response.union_all([response1, response2])
     assert final_response.get_json_response() == {
         "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-            {"text": "bye", "token_logprobs": [2]},
-            {"text": "bye 2", "token_logprobs": [2]},
-            {"text": "bye 3", "token_logprobs": [2]},
+            {"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None},
+            {"text": "bye", "token_logprobs": [0.3], "tokens": None},
+            {"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None},
         ]
     }
-    final_request = LMRequest(
-        prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
-    ).to_dict()
-    final_request["model"] = "modelA"
-    assert final_response.get_request() == final_request
-    assert not final_response.is_cached()
+    assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
+    merged_prompts: List[str] = request_lm.prompt + [request_lm_single.prompt]  # type: ignore # noqa: E501
+    assert final_response.get_request_obj().prompt == merged_prompts
+    assert final_response.get_request_obj().engine == "dummy::text-ada-001"

     # Modify A to have usage and cached
-    response_paramsa_2: Dict[str, Any] = {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-        ],
-        "usage": [
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-        ],
-    }
-    responsea = Response(response_paramsa_2, True, request_paramsa)
+    response1 = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=LMRequest,
+        response_type="text",
+    )

-    final_response = Response.union_all([responsea, responseb])
-    assert final_response.get_json_response() == {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-            {"text": "bye", "token_logprobs": [2]},
-            {"text": "bye 2", "token_logprobs": [2]},
-            {"text": "bye 3", "token_logprobs": [2]},
-        ],
-        "usage": [
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-            {},
-            {},
-            {},
-        ],
-    }
-    final_request = LMRequest(
-        prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
-    ).to_dict()
-    final_request["model"] = "modelA"
-    assert final_response.get_request() == final_request
-    assert final_response.is_cached()
+    final_response = Response.union_all([response1, response2])
+    assert final_response.get_usage_obj() == Usages(
+        usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()]
+    )
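
As the rewritten test shows, `union_all` now merges typed objects end to end: choices are concatenated, missing usages are padded with empty `Usage()` entries so they stay aligned with choices, and the merged request joins the prompts and the engine names. A sketch with hand-built inputs, assuming `LMModelChoice` and `LMRequest` accept the field names shown in these tests:

```python
from manifest import Response
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Usage, Usages

response1 = Response(
    response=ModelChoices(choices=[LMModelChoice(text="hello", token_logprobs=[0.1])]),
    cached=False,
    request=LMRequest(prompt=["apple"], engine="dummy"),  # assumed field names
    usages=Usages(usages=[Usage(total_tokens=4)]),
    request_type=LMRequest,
    response_type="text",
)
response2 = Response(
    response=ModelChoices(choices=[LMModelChoice(text="bye", token_logprobs=[0.3])]),
    cached=False,
    request=LMRequest(prompt="banana", engine="text-ada-001"),
    usages=None,
    request_type=LMRequest,
    response_type="text",
)

merged = Response.union_all([response1, response2])
assert [c.text for c in merged.get_response_obj().choices] == ["hello", "bye"]
# The usage for response2 is padded with an empty Usage() entry.
assert merged.get_usage_obj() == Usages(usages=[Usage(total_tokens=4), Usage()])
```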

@@ -10,23 +10,23 @@ def test_response_to_key_array() -> None:
     """Test array serializer initialization."""
     serializer = ArraySerializer()
     arr = np.random.rand(4, 4)
-    res = {"choices": [{"array": arr}]}
+    res = {"response": {"choices": [{"array": arr}]}}
     key = serializer.response_to_key(res)
     key_dct = json.loads(key)
-    assert isinstance(key_dct["choices"][0]["array"], str)
+    assert isinstance(key_dct["response"]["choices"][0]["array"], str)
     res2 = serializer.key_to_response(key)
-    assert np.allclose(arr, res2["choices"][0]["array"])
+    assert np.allclose(arr, res2["response"]["choices"][0]["array"])


 def test_response_to_key_numpybytes() -> None:
     """Test array serializer initialization."""
     serializer = NumpyByteSerializer()
     arr = np.random.rand(4, 4)
-    res = {"choices": [{"array": arr}]}
+    res = {"response": {"choices": [{"array": arr}]}}
     key = serializer.response_to_key(res)
     key_dct = json.loads(key)
-    assert isinstance(key_dct["choices"][0]["array"], str)
+    assert isinstance(key_dct["response"]["choices"][0]["array"], str)
     res2 = serializer.key_to_response(key)
-    assert np.allclose(arr, res2["choices"][0]["array"])
+    assert np.allclose(arr, res2["response"]["choices"][0]["array"])
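
Because arrays now live under the top-level `response` key of the serialized dict, the serializers walk one level deeper. A sketch of the round trip, assuming the `ArraySerializer` used above (the import path is an assumption):

```python
import numpy as np

from manifest.caches.serializers import ArraySerializer  # assumed module path

serializer = ArraySerializer()
arr = np.random.rand(4, 4)
res = {"response": {"choices": [{"array": arr}]}}  # arrays nest under "response" now

key = serializer.response_to_key(res)  # arrays are stringified into the cache key
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["response"]["choices"][0]["array"])
```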

@@ -49,7 +49,7 @@ def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
     return {
         "response": response.get_response(),
         "cached": response.is_cached(),
-        "request_params": response.get_request(),
+        "request_params": response.get_request_obj(),
     }
