fix: added pydantic types to response (#84)

pull/85/head
Laurel Orr authored 1 year ago, committed via GitHub
parent 4602fb919b
commit d7401c6ec5

@ -8,6 +8,7 @@ Added
Fixed
^^^^^
* Determine cache and response by request type, not client name
* Refactor Response to use Pydantic types for Request and Response
0.1.1
---------------------

@ -126,9 +126,9 @@ results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the
If something doesn't go right, you can also ask for the raw manifest Response.
```python
result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
print(result_object.get_request())
print(result_object.get_request_obj())
print(result_object.is_cached())
print(result_object.get_json_response())
print(result_object.get_response_obj())
```
By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run`.
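A minimal sketch of both options (assuming an OpenAI-backed session as configured above; the `stop_token` keyword is taken from the library's session and `run` parameters and may differ in your version):
```python
from manifest import Manifest

# Session-wide stop token: every run truncates at the first newline.
manifest = Manifest(client_name="openai", stop_token="\n")

# Per-call override: truncate this result at the first period instead.
result = manifest.run("Where are the cats?", stop_token=".")
```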

@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -89,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -131,23 +131,23 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For loop: 229.93\n",
"For loop: 128.68\n",
"Running with async single client\n",
"Running 1 tasks across all clients.\n",
"Async loop: 1.39\n",
"Async loop: 4.02\n",
"Running with async two clients but not chunking\n",
"Running 1 tasks across all clients.\n",
"Async loop: 1.65\n",
"Async loop: 3.92\n",
"Running with async two clients and chunk size\n",
"Running 20 tasks across all clients.\n",
"Async loop: 0.64\n"
"Async loop: 1.44\n"
]
}
],
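For context, a hedged sketch of the pattern these timings compare (the notebook's code cells are elided above; `manifest` is assumed to be an already-configured session, and `arun_batch` follows the README usage earlier in this diff):
```python
import asyncio
import time

prompts = ["Where are the cats?"] * 20

# Sequential baseline: one synchronous manifest.run call per prompt.
start = time.time()
for prompt in prompts:
    manifest.run(prompt)
print(f"For loop: {time.time() - start:.2f}")

# Batched async path: a single arun_batch call fans the prompts out
# across the configured client(s).
start = time.time()
asyncio.run(manifest.arun_batch(prompts))
print(f"Async loop: {time.time() - start:.2f}")
```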

@ -4,7 +4,7 @@ from typing import Any, Dict, Type, Union
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
from manifest.response import Response
# Non-text return type caches
ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
@ -119,14 +119,9 @@ class Cache(ABC):
key = self.serializer.request_to_key(request)
cached_response = self.get_key(key)
if cached_response:
cached = True
response = self.serializer.key_to_response(cached_response)
return Response(
response,
cached,
request,
**RESPONSE_CONSTRUCTORS.get(self.request_type, {}),
)
response["cached"] = True
return Response.from_dict(response, request_dict=request)
return None
def set(self, request: Dict, response: Dict) -> None:
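To make the new round trip concrete, a hedged sketch of what `set`/`get` now store and rebuild (constructor and method names follow the diffs in this commit; treat it as illustrative rather than the exact cache internals):
```python
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Response

request = LMRequest(prompt="hello")
response = Response(
    response=ModelChoices(choices=[LMModelChoice(text="world")]),
    cached=False,
    request=request,
    usages=None,
    response_type="text",
    request_type=LMRequest,
)

stored = response.to_dict(drop_request=True)  # what set() persists
stored["cached"] = True                       # flipped on a cache hit
rebuilt = Response.from_dict(stored, request_dict=request.dict())
```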

@ -77,14 +77,15 @@ class NumpyByteSerializer(Serializer):
Returns:
normalized key.
"""
sub_response = response["response"]
# Assume response is a dict with keys "choices" -> List dicts
# with keys "array".
choices = response["choices"]
choices = sub_response["choices"]
# We don't want to modify the response in place
# but we want to avoid calling deepcopy on an array
del response["choices"]
response_copy = response.copy()
response["choices"] = choices
del sub_response["choices"]
response_copy = sub_response.copy()
sub_response["choices"] = choices
response_copy["choices"] = []
for choice in choices:
if "array" not in choice:
@ -101,7 +102,8 @@ class NumpyByteSerializer(Serializer):
hash_str = f.getvalue().hex()
new_choice["array"] = hash_str
response_copy["choices"].append(new_choice)
return json.dumps(response_copy, sort_keys=True)
response["response"] = response_copy
return json.dumps(response, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
@ -114,7 +116,7 @@ class NumpyByteSerializer(Serializer):
unnormalized response dict.
"""
response = json.loads(key)
for choice in response["choices"]:
for choice in response["response"]["choices"]:
hash_str = choice["array"]
byte_str = bytes.fromhex(hash_str)
with io.BytesIO(byte_str) as f:
@ -152,14 +154,15 @@ class ArraySerializer(Serializer):
Returns:
normalized key.
"""
sub_response = response["response"]
# Assume response is a dict with keys "choices" -> List dicts
# with keys "array".
choices = response["choices"]
choices = sub_response["choices"]
# We don't want to modify the response in place
# but we want to avoid calling deepcopy on an array
del response["choices"]
response_copy = response.copy()
response["choices"] = choices
del sub_response["choices"]
response_copy = sub_response.copy()
sub_response["choices"] = choices
response_copy["choices"] = []
for choice in choices:
if "array" not in choice:
@ -179,7 +182,8 @@ class ArraySerializer(Serializer):
response_copy["choices"].append(new_choice)
if not self.writer.contains_key(hash_str):
self.writer.put(hash_str, arr)
return json.dumps(response_copy, sort_keys=True)
response["response"] = response_copy
return json.dumps(response, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
@ -194,7 +198,7 @@ class ArraySerializer(Serializer):
unnormalized response dict.
"""
response = json.loads(key)
for choice in response["choices"]:
for choice in response["response"]["choices"]:
hash_str = choice["array"]
choice["array"] = self.writer.get(hash_str)
return response
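For orientation, a hedged sketch of the serialized shape both serializers now produce: the array payload is swapped for a hash (ArraySerializer) or a hex byte string (NumpyByteSerializer) under the nested `response` key instead of at the top level. Field names follow `Response.to_dict`; values are illustrative only.
```python
# Illustrative shape of the normalized key before json.dumps:
serialized = {
    "response": {
        "choices": [
            {"array": "<hash-or-hex-string>", "token_logprobs": None},
        ],
    },
    "usages": {"usages": []},
    "cached": False,
    "response_type": "array",
    "request_type": "DiffusionRequest",
    "item_dtype": "float64",
}
```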

@ -94,7 +94,7 @@ class AI21Client(Client):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -10,8 +10,21 @@ import aiohttp
import requests
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
from manifest.request import (
DEFAULT_REQUEST_KEYS,
NOT_CACHE_KEYS,
LMScoreRequest,
Request,
)
from manifest.response import (
RESPONSE_CONSTRUCTORS,
ArrayModelChoice,
LMModelChoice,
ModelChoices,
Response,
Usage,
Usages,
)
logger = logging.getLogger(__name__)
@ -161,16 +174,30 @@ class Client(ABC):
request_params = self.get_request_params(request)
for key in NOT_CACHE_KEYS:
request_params.pop(key, None)
# Make sure to add model params and request class
request_params.update(self.get_model_params())
request_params["request_cls"] = request.__class__.__name__
return request_params
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
return []
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def get_model_choices(self, response: Dict) -> ModelChoices:
"""Format response to ModelChoices."""
# Array or text response
response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
if response_type == "array":
choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
ArrayModelChoice(**choice) for choice in response["choices"]
]
else:
choices = [LMModelChoice(**choice) for choice in response["choices"]]
return ModelChoices(choices=choices)
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Validate response as dict.
Args:
response: response
@ -246,7 +273,7 @@ class Client(ABC):
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json(), request_params)
return self.validate_response(res.json(), request_params)
@retry(
reraise=True,
@ -277,7 +304,7 @@ class Client(ABC):
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json, request_params)
return self.validate_response(res_json, request_params)
def run_request(self, request: Request) -> Response:
"""
@ -301,11 +328,16 @@ class Client(ABC):
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
usages = None
if "usage" in response_dict:
usages = [Usage(**usage) for usage in response_dict["usage"]]
return Response(
response_dict,
response=self.get_model_choices(response_dict),
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore
request=request,
usages=Usages(usages=usages) if usages else None,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
async def arun_batch_request(self, request: Request) -> Response:
@ -353,18 +385,20 @@ class Client(ABC):
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
final_usages = None
if usages:
final_response_dict["usage"] = usages
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
return Response(
final_response_dict,
self.get_model_choices(final_response_dict),
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore
request=request,
usages=final_usages,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
def get_score_prompt_request(
self,
request: Request,
request: LMScoreRequest,
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.

@ -93,7 +93,7 @@ class CohereClient(Client):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -86,7 +86,7 @@ class DiffuserClient(Client):
res["client_name"] = self.NAME
return res
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -3,8 +3,8 @@ import logging
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
from manifest.response import Response
from manifest.request import LMRequest, LMScoreRequest, Request
from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages
logger = logging.getLogger(__name__)
@ -86,15 +86,30 @@ class DummyClient(Client):
num_results = 1
request_params = request.to_dict(self.PARAMS)
response_dict = {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}]
* int(request_params["num_results"])
* num_results,
}
return Response(response_dict, False, request_params)
return Response(
response=ModelChoices(
choices=[LMModelChoice(text="hello")] # type: ignore
* int(request_params["num_results"])
* num_results
),
cached=False,
request=request,
usages=Usages(
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
* int(request_params["num_results"])
* num_results
),
response_type="text",
request_type=self.REQUEST_CLS,
)
async def arun_batch_request(self, request: Request) -> Response:
"""
@ -110,7 +125,7 @@ class DummyClient(Client):
def get_score_prompt_request(
self,
request: Request,
request: LMScoreRequest,
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -126,17 +141,27 @@ class DummyClient(Client):
num_results = len(request.prompt)
else:
num_results = 1
request_params = {"prompt": request.prompt}
response_dict = {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
"token_logprobs": [0.3],
}
for i in range(num_results)
]
}
return Response(response_dict, False, request_params)
return Response(
response=ModelChoices(
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
cached=False,
request=request,
usages=None,
response_type="text",
request_type=LMScoreRequest,
)

@ -5,8 +5,8 @@ from typing import Any, Dict, Optional
import requests
from manifest.clients.client import Client
from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request
from manifest.response import Response
from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, LMScoreRequest
from manifest.response import LMModelChoice, ModelChoices, Response
logger = logging.getLogger(__name__)
@ -82,7 +82,7 @@ class HuggingFaceClient(Client):
def get_score_prompt_request(
self,
request: Request,
request: LMScoreRequest,
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -116,4 +116,13 @@ class HuggingFaceClient(Client):
logger.error(res.text)
raise e
response_dict = res.json()
return Response(response_dict, cached=False, request_params=request_params)
return Response(
response=ModelChoices(
choices=[LMModelChoice(**choice) for choice in response_dict["choices"]]
),
cached=False,
request=request,
usages=None,
response_type="text",
request_type=LMScoreRequest,
)

@ -72,7 +72,7 @@ class HuggingFaceEmbeddingClient(Client):
res["client_name"] = self.NAME
return res
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -37,6 +37,7 @@ class OpenAIClient(Client):
"n": ("n", 1),
"top_p": ("top_p", 1.0),
"top_k": ("best_of", 1),
"logprobs": ("logprobs", None),
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),

@ -76,7 +76,7 @@ class OpenAIEmbeddingClient(OpenAIClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -143,7 +143,7 @@ class TOMAClient(Client):
}
return heartbeats
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -46,7 +46,7 @@ class TOMADiffuserClient(TOMAClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -2,7 +2,7 @@
import asyncio
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
import numpy as np
@ -17,8 +17,8 @@ from manifest.connections.client_pool import (
ClientConnection,
ClientConnectionPool,
)
from manifest.request import Request
from manifest.response import Response
from manifest.request import LMScoreRequest, Request
from manifest.response import ModelChoices, Response, Usage, Usages
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
@ -178,82 +178,72 @@ class Manifest:
number_prompts = len(cached_idx_to_response)
single_output = False
if response:
if isinstance(response.get_request()["prompt"], str):
if isinstance(response.get_request_obj().prompt, str):
single_output = True
number_prompts += 1
else:
number_prompts += len(response.get_request()["prompt"])
response_gen_key = None
response_logits_key = None
response_item_key = None
number_prompts += len(response.get_request_obj().prompt)
response_type = None
request_type: Type[Request] = None
for idx in range(number_prompts):
if idx in cached_idx_to_response:
cached_res = cached_idx_to_response[idx]
response_gen_key = cached_res.generation_key
response_logits_key = cached_res.logits_key
response_item_key = cached_res.item_key
response_usage_key = cached_res.usage_key
all_input_prompts.append(cached_res.get_request()["prompt"])
json_response = cached_res.get_json_response()
response_type = cached_res._response_type
request_type = cached_res._request_type
all_input_prompts.append(cached_res.get_request_obj().prompt)
if request.n == 1:
assert (
len(json_response[response_gen_key]) == 1
len(cached_res.get_response_obj().choices) == 1
), "cached response should have only one choice"
all_model_choices.extend(json_response[response_gen_key])
if response_usage_key:
all_usages.extend(json_response[response_usage_key])
all_model_choices.extend(cached_res.get_response_obj().choices)
if cached_res.get_usage_obj().usages:
all_usages.extend(cached_res.get_usage_obj().usages)
else:
assert response is not None, "response should not be None"
response = cast(Response, response)
response_gen_key = response.generation_key
response_logits_key = response.logits_key
response_item_key = response.item_key
response_usage_key = response.usage_key
response_type = response._response_type
request_type = response._request_type
# the choices list in the response is a flat one.
# length is request.n * num_prompts
current_choices = response.get_json_response()[response_gen_key][
current_choices = response.get_response_obj().choices[
response_idx * request.n : (response_idx + 1) * request.n
]
all_model_choices.extend(current_choices)
if isinstance(response.get_request()["prompt"], list):
prompt = response.get_request()["prompt"][response_idx]
if isinstance(response.get_request_obj().prompt, list):
prompt = response.get_request_obj().prompt[response_idx]
else:
prompt = str(response.get_request()["prompt"])
if response_usage_key:
usage = response.get_json_response()[response_usage_key][
prompt = str(response.get_request_obj().prompt)
usages: Optional[List[Usage]] = None
if response.get_usage_obj().usages:
usages = response.get_usage_obj().usages[
response_idx * request.n : (response_idx + 1) * request.n
]
all_usages.extend(usage)
all_usages.extend(usages)
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt
cache_key = client.get_cache_key(new_request)
new_response_key = copy.deepcopy(response.get_json_response())
new_response_key[response_gen_key] = current_choices
if response_usage_key:
new_response_key[response_usage_key] = usage
self.cache.set(cache_key, new_response_key)
new_response = copy.deepcopy(response)
new_response._response.choices = current_choices
new_response._usages = Usages(usages=(usages or []))
self.cache.set(cache_key, new_response.to_dict(drop_request=True))
response_idx += 1
new_request = copy.deepcopy(request)
new_request.prompt = (
all_input_prompts
all_input_prompts # type: ignore
if len(all_input_prompts) > 1 or not single_output
else all_input_prompts[0]
)
new_response = {response_gen_key: all_model_choices}
if response_usage_key:
new_response[response_usage_key] = all_usages
response_obj = Response(
new_response,
response=ModelChoices(choices=all_model_choices),
cached=len(cached_idx_to_response) > 0,
request_params=client.get_cache_key(new_request),
generation_key=response_gen_key,
logits_key=response_logits_key,
item_key=response_item_key,
usage_key=response_usage_key,
request=new_request,
usages=Usages(usages=all_usages),
response_type=response_type,
request_type=request_type,
)
return response_obj
@ -457,20 +447,20 @@ class Manifest:
client = self.client_pool.get_client()
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
request_params = client.get_request(prompt, kwargs)
request_params.request_type = "score_prompt"
request_params_as_score = LMScoreRequest(**request_params.to_dict())
if request_params.n > 1:
if request_params_as_score.n > 1:
raise ValueError("Sequence scoring does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
self._validate_kwargs(kwargs, request_params_as_score)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, client, overwrite_cache
cached_idx_to_response, request_params_as_score = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_score, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
if request_params_as_score.prompt:
try:
response = cast(HuggingFaceClient, client).get_score_prompt_request(
request_params
request_params_as_score
)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
@ -479,7 +469,7 @@ class Manifest:
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
request=request_params_as_score,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,

@ -3,13 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
# Used when unioning requests after async connection pool
ENGINE_SEP = "::"
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
# The below should match those in Request.
DEFAULT_REQUEST_KEYS = {
"client_timeout": ("client_timeout", 60), # seconds
"batch_size": ("batch_size", 8),
"run_id": ("run_id", None),
"request_type": ("request_type", None),
}
@ -34,9 +35,6 @@ class Request(BaseModel):
# Batch size for async batch run
batch_size: int = 8
# Request type None is for completion. Used for scoring prompt
request_type: str = None
def to_dict(
self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
) -> Dict[str, Any]:
@ -78,6 +76,9 @@ class LMRequest(Request):
# Top k sampling taking top_k highest probability tokens
top_k: int = 50
# Logprobs return value
logprobs: Optional[int] = None
# Stop sequences
stop_sequences: Optional[List[str]] = None
@ -100,6 +101,12 @@ class LMRequest(Request):
frequency_penalty: float = 0
class LMScoreRequest(LMRequest):
"""Language Model Score Request object."""
pass
class EmbeddingRequest(Request):
"""Embedding Request object."""
@ -109,9 +116,6 @@ class EmbeddingRequest(Request):
class DiffusionRequest(Request):
"""Diffusion Model Request object."""
# Request type
request_type: str = "diffusion"
# Number of steps
num_inference_steps: int = 50
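The net effect of this file's changes, as a hedged sketch: scoring intent now lives in the request class (`LMScoreRequest`) instead of a `request_type` string, and clients stamp the class name into cache keys as `request_cls` (see client.py above), so cached completions and cached scores no longer collide.
```python
from manifest.request import LMRequest, LMScoreRequest

completion_request = LMRequest(prompt="Hello is a prompt")

# manifest.py (above) rebuilds the request as an LMScoreRequest for scoring
# rather than tagging request_type="score_prompt".
score_request = LMScoreRequest(**completion_request.to_dict())

print(score_request.__class__.__name__)  # "LMScoreRequest" -> the cache key's request_cls
```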

@ -1,21 +1,25 @@
"""Client response."""
import copy
import json
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Type, Union, cast
import numpy as np
from manifest.request import DiffusionRequest, EmbeddingRequest
RESPONSE_CONSTRUCTORS = {
EmbeddingRequest: {
"logits_key": "token_logprobs",
"item_key": "array",
},
DiffusionRequest: {
"logits_key": "token_logprobs",
"item_key": "array",
},
from pydantic import BaseModel
from manifest.request import (
ENGINE_SEP,
DiffusionRequest,
EmbeddingRequest,
LMRequest,
LMScoreRequest,
Request,
)
RESPONSE_CONSTRUCTORS: Dict[Type[Request], Dict[str, Union[str, Type[Request]]]] = {
LMRequest: {"response_type": "text", "request_type": LMRequest},
LMScoreRequest: {"response_type": "text", "request_type": LMScoreRequest},
EmbeddingRequest: {"response_type": "array", "request_type": EmbeddingRequest},
DiffusionRequest: {"response_type": "array", "request_type": DiffusionRequest},
}
@ -29,94 +33,114 @@ class NumpyArrayEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, obj)
class Usage(BaseModel):
"""Prompt usage class."""
completion_tokens: int = 0
prompt_tokens: int = 0
total_tokens: int = 0
class Usages(BaseModel):
"""Prompt usage class."""
usages: List[Usage]
class LMModelChoice(BaseModel):
"""Model single completion."""
text: str
token_logprobs: Optional[List[float]] = None
tokens: Optional[List[int]] = None
class ArrayModelChoice(BaseModel):
"""Model single completion."""
array: np.ndarray
token_logprobs: Optional[List[float]] = None
class Config:
"""Pydantic config class."""
arbitrary_types_allowed = True
class ModelChoices(BaseModel):
"""Model choices."""
choices: List[Union[LMModelChoice, ArrayModelChoice]]
class Response:
"""Response class."""
def __init__(
self,
response: Dict, # TODO: make pydantic model
response: ModelChoices,
cached: bool,
request_params: Dict, # TODO: use request pydantic model
generation_key: str = "choices",
logits_key: str = "token_logprobs",
item_key: str = "text",
usage_key: str = "usage",
request: Request,
response_type: str,
request_type: Type[Request],
usages: Optional[Usages] = None,
):
"""
Initialize response.
Args:
response: response dict.
usages: usage dict.
cached: whether response is cached.
request_params: request parameters.
generation_key: key for generation results.
logits_key: key for logits.
item_key: key for item in the generations.
request: request.
response_type: response type.
request_type: request type.
"""
self.generation_key = generation_key
self.logits_key = logits_key
self.item_key = item_key
self.usage_key = usage_key
self.item_dtype = None
if isinstance(response, dict):
self._response = response
else:
raise ValueError(f"Response must be dict. Response is\n{response}.")
if (
(self.generation_key not in self._response)
or (not isinstance(self._response[self.generation_key], list))
or (len(self._response[self.generation_key]) <= 0)
):
raise ValueError(
"Response must be serialized to a dict with a nonempty"
f" list of choices. Response is\n{self._response}."
)
# Turn off usage if it is not in response
if self.usage_key not in self._response:
self.usage_key = None
else:
if not isinstance(self._response[self.usage_key], list):
raise ValueError(
"Response must be a list with usage dicts, one per choice."
f" Response is\n{self._response}."
)
if self.item_key not in self._response[self.generation_key][0]:
raise ValueError(
"Response must be serialized to a dict with a "
f"list of choices with {self.item_key} field"
)
if (
self.logits_key in self._response[self.generation_key][0]
and self._response[self.generation_key][0][self.logits_key]
):
if not isinstance(
self._response[self.generation_key][0][self.logits_key], list
):
raise ValueError(
f"{self.logits_key} must be a list of items "
"one for each token in the choice."
)
if isinstance(
self._response[self.generation_key][0][self.item_key], np.ndarray
):
self.item_dtype = str(
self._response[self.generation_key][0][self.item_key].dtype
)
self._item_dtype = None
self._response_type = response_type
if self._response_type not in {"array", "text"}:
raise ValueError(f"Invalid response type {self._response_type}")
self._request_type = request_type
self._response = response
self._usages = usages or Usages(usages=[])
self._cached = cached
self._request_params = request_params
self._request = request
if self._response.choices:
if response_type == "array":
if not isinstance(self._response.choices[0], ArrayModelChoice):
raise ValueError(
"response_type is array but response is "
f"{self._response.choices[0].__class__}"
)
self._item_dtype = str(
cast(ArrayModelChoice, self._response.choices[0]).array.dtype
)
else:
if not isinstance(self._response.choices[0], LMModelChoice):
raise ValueError(
"response_type is text but response is "
f"{self._response.choices[0].__class__}"
)
def is_cached(self) -> bool:
"""Check if response is cached."""
return self._cached
def get_request(self) -> Dict:
def get_request_obj(self) -> Request:
"""Get request parameters."""
return self._request_params
return self._request
def get_response_obj(self) -> ModelChoices:
"""Get response object."""
return self._response
def get_usage_obj(self) -> Usages:
"""Get usage object."""
return self._usages
def get_json_response(self) -> Dict:
"""Get response dict without parsing."""
return self._response
return self._response.dict()
def get_response(
self, stop_token: str = "", is_batch: bool = False
@ -132,7 +156,8 @@ class Response:
lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
)
extracted_items = [
choice[self.item_key] for choice in self._response[self.generation_key]
choice.text if isinstance(choice, LMModelChoice) else choice.array
for choice in self._response.choices
]
if len(extracted_items) == 0:
return None
@ -153,25 +178,15 @@ class Response:
if len(responses) == 1:
return responses[0]
first_response = responses[0]
generation_key = first_response.generation_key
logits_key = first_response.logits_key
item_key = first_response.item_key
# Usage key may be None, so get first not-None value
possible_usage_keys = [r.usage_key for r in responses if r.usage_key]
if possible_usage_keys:
usage_key = possible_usage_keys[0]
else:
usage_key = None
request = first_response._request_params
request_type = first_response._request_type
response_type = first_response._response_type
request = first_response.get_request_obj()
# Make sure all responses have the same keys
if not all(
[
(r.generation_key == generation_key)
and (r.logits_key == logits_key)
and (r.item_key == item_key)
# Usage key can be empty
and (not r.usage_key or not usage_key or r.usage_key == usage_key)
(r._request_type == request_type)
and (r._response_type == response_type)
for r in responses
]
):
@ -181,33 +196,31 @@ class Response:
all_prompts = []
all_choices = []
all_usages = []
all_engines = []
for res in responses:
json_response = res.get_json_response()
res_prompt = res.get_request()["prompt"]
all_engines.extend(res.get_request_obj().engine.split(ENGINE_SEP))
res_prompt = res.get_request_obj().prompt
if isinstance(res_prompt, str):
res_prompt = [res_prompt]
all_prompts.extend(res_prompt)
all_choices.extend(json_response[generation_key])
if usage_key and usage_key in json_response:
all_usages.extend(json_response[usage_key])
all_choices.extend(res.get_response_obj().choices)
if res.get_usage_obj().usages:
all_usages.extend(res.get_usage_obj().usages)
else:
# Add empty usage
all_usages.extend([{}] * len(res_prompt))
# Add empty usages if not present
all_usages.extend([Usage()] * len(res_prompt))
new_request = copy.deepcopy(request)
# TODO: add both models back in request. This should be a lot
# easier after I pydantic the response and request more formally
new_request["prompt"] = all_prompts
new_response = {generation_key: all_choices}
if usage_key:
new_response[usage_key] = all_usages
new_request.engine = ENGINE_SEP.join(sorted(set(all_engines)))
new_request.prompt = all_prompts
new_response = ModelChoices(choices=all_choices)
new_usages = Usages(usages=all_usages)
response_obj = cls(
new_response,
response=new_response,
cached=any(res.is_cached() for res in responses),
request_params=new_request,
generation_key=generation_key,
logits_key=logits_key,
item_key=item_key,
usage_key=usage_key,
request=new_request,
usages=new_usages,
request_type=request_type,
response_type=response_type,
)
return response_obj
@ -232,56 +245,74 @@ class Response:
serialized response.
"""
deserialized = json.loads(value)
item_dtype = deserialized["item_dtype"]
if item_dtype:
for choice in deserialized["response"][deserialized["generation_key"]]:
choice[deserialized["item_key"]] = np.array(
choice[deserialized["item_key"]]
).astype(item_dtype)
return cls(
deserialized["response"],
deserialized["cached"],
deserialized["request_params"],
generation_key=deserialized["generation_key"],
logits_key=deserialized["logits_key"],
item_key=deserialized["item_key"],
)
return cls.from_dict(deserialized)
def to_dict(self) -> Dict:
def to_dict(self, drop_request: bool = False) -> Dict:
"""
Get dictionary representation of response.
Returns:
dictionary representation of response.
"""
return {
"generation_key": self.generation_key,
"logits_key": self.logits_key,
"item_key": self.item_key,
"item_dtype": self.item_dtype,
"response": self._response,
to_return = {
"response": self._response.dict(),
"usages": self._usages.dict(),
"cached": self._cached,
"request_params": self._request_params,
"request": self._request.dict(),
"response_type": self._response_type,
"request_type": str(self._request_type.__name__),
"item_dtype": self._item_dtype,
}
if drop_request:
to_return.pop("request")
return to_return
@classmethod
def from_dict(cls, response: Dict) -> "Response":
def from_dict(
cls, response_dict: Dict, request_dict: Optional[Dict] = None
) -> "Response":
"""
Create response from dictionary.
Args:
response: dictionary representation of response.
request_dict: dictionary representation of request which
will override what is in response_dict.
Returns:
response.
"""
if "request" not in response_dict and request_dict is None:
raise ValueError(
"Request dictionary must be provided if "
"request is not in response dictionary."
)
item_dtype = response_dict["item_dtype"]
response_type = response_dict["response_type"]
if response_dict["request_type"] == "LMRequest":
request_type: Type[Request] = LMRequest
elif response_dict["request_type"] == "LMScoreRequest":
request_type = LMScoreRequest
elif response_dict["request_type"] == "EmbeddingRequest":
request_type = EmbeddingRequest
elif response_dict["request_type"] == "DiffusionRequest":
request_type = DiffusionRequest
choices: List[Union[LMModelChoice, ArrayModelChoice]] = []
if item_dtype and response_type == "array":
for choice in response_dict["response"]["choices"]:
choice["array"] = np.array(choice["array"]).astype(item_dtype)
choices.append(ArrayModelChoice(**choice))
else:
for choice in response_dict["response"]["choices"]:
choices.append(LMModelChoice(**choice))
response = ModelChoices(choices=choices)
return cls(
response["response"],
response["cached"],
response["request_params"],
generation_key=response["generation_key"],
logits_key=response["logits_key"],
item_key=response["item_key"],
response=response,
usages=Usages(**response_dict["usages"]),
cached=response_dict["cached"],
request=request_type(**(request_dict or response_dict["request"])),
response_type=response_type,
request_type=request_type,
)
def __str__(self) -> str:
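To summarize the new response model in one place, a hedged sketch of the pydantic pieces defined above (values mirror the test fixtures added later in this diff):
```python
import numpy as np

from manifest.response import (
    ArrayModelChoice,
    LMModelChoice,
    ModelChoices,
    Usage,
    Usages,
)

# Text completions carry optional token-level logprobs and tokens.
text_choices = ModelChoices(
    choices=[LMModelChoice(text="hello", token_logprobs=[0.1, 0.2])]
)

# Array responses (embeddings, diffusion) wrap numpy arrays directly.
array_choices = ModelChoices(
    choices=[ArrayModelChoice(array=np.random.randn(4, 4))]
)

# Per-choice usage accounting replaces the old "usage" list of dicts.
usages = Usages(
    usages=[Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2)]
)
```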

@ -4,9 +4,94 @@ import shutil
from pathlib import Path
from typing import Generator
import numpy as np
import pytest
import redis
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest
from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices
@pytest.fixture
def model_choice() -> ModelChoices:
"""Get dummy model choice."""
model_choices = ModelChoices(
choices=[
LMModelChoice(text="hello", token_logprobs=[0.1, 0.2]),
LMModelChoice(text="bye", token_logprobs=[0.3]),
]
)
return model_choices
@pytest.fixture
def model_choice_single() -> ModelChoices:
"""Get dummy model choice."""
model_choices = ModelChoices(
choices=[
LMModelChoice(text="helloo", token_logprobs=[0.1, 0.2]),
]
)
return model_choices
@pytest.fixture
def model_choice_arr() -> ModelChoices:
"""Get dummy model choice."""
np.random.seed(0)
model_choices = ModelChoices(
choices=[
ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.1, 0.2]),
ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.3]),
]
)
return model_choices
@pytest.fixture
def model_choice_arr_int() -> ModelChoices:
"""Get dummy model choice."""
np.random.seed(0)
model_choices = ModelChoices(
choices=[
ArrayModelChoice(
array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.1, 0.2]
),
ArrayModelChoice(
array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.3]
),
]
)
return model_choices
@pytest.fixture
def request_lm() -> LMRequest:
"""Get dummy request."""
request = LMRequest(prompt=["what", "cat"])
return request
@pytest.fixture
def request_lm_single() -> LMRequest:
"""Get dummy request."""
request = LMRequest(prompt="monkey", engine="dummy")
return request
@pytest.fixture
def request_array() -> EmbeddingRequest:
"""Get dummy request."""
request = EmbeddingRequest(prompt="hello")
return request
@pytest.fixture
def request_diff() -> DiffusionRequest:
"""Get dummy request."""
request = DiffusionRequest(prompt="hello")
return request
@pytest.fixture
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:

@ -12,6 +12,7 @@ from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.request import DiffusionRequest, LMRequest, Request
from manifest.response import ArrayModelChoice, ModelChoices, Response
def _get_postgres_cache(
@ -78,7 +79,16 @@ def test_key_get_and_set(
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_get(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
sqlite_cache: str,
redis_cache: str,
postgres_cache: str,
cache_type: str,
model_choice: ModelChoices,
model_choice_single: ModelChoices,
model_choice_arr_int: ModelChoices,
request_lm: LMRequest,
request_lm_single: LMRequest,
request_diff: DiffusionRequest,
) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
@ -88,22 +98,51 @@ def test_get(
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": "hello", "testA": "world"}
test_response = {"choices": [{"text": "hello"}]}
response = Response(
response=model_choice_single,
cached=False,
request=request_lm_single,
usages=None,
request_type=LMRequest,
response_type="text",
)
response = cache.get(test_request)
assert response is None
cache_response = cache.get(request_lm_single.dict())
assert cache_response is None
cache.set(request_lm_single.dict(), response.to_dict(drop_request=True))
cache_response = cache.get(request_lm_single.dict())
assert cache_response.get_response() == "helloo"
assert cache_response.is_cached()
assert cache_response.get_request_obj() == request_lm_single
response = Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == "hello"
assert response.is_cached()
assert response.get_request() == test_request
cache_response = cache.get(request_lm.dict())
assert cache_response is None
cache.set(request_lm.dict(), response.to_dict(drop_request=True))
cache_response = cache.get(request_lm.dict())
assert cache_response.get_response() == ["hello", "bye"]
assert cache_response.is_cached()
assert cache_response.get_request_obj() == request_lm
# Test array
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}]}
response = Response(
response=model_choice_arr_int,
cached=False,
request=request_diff,
usages=None,
request_type=DiffusionRequest,
response_type="array",
)
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
@ -112,103 +151,34 @@ def test_get(
elif cache_type == "postgres":
cache = _get_postgres_cache(request_type=DiffusionRequest)
response = cache.get(test_request)
assert response is None
cache_response = cache.get(request_diff.dict())
assert cache_response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert response.is_cached()
assert response.get_request() == test_request
cache.set(request_diff.dict(), response.to_dict(drop_request=True))
cached_response = cache.get(request_diff.dict())
assert np.allclose(
cached_response.get_response()[0],
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
)
assert np.allclose(
cached_response.get_response()[1],
cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array,
)
assert cached_response.is_cached()
assert cached_response.get_request_obj() == request_diff
# Test array byte string
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}]}
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert response.is_cached()
assert response.get_request() == test_request
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_get_batch_prompt(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
cache = cast(Cache, SQLiteCache(sqlite_cache))
elif cache_type == "redis":
cache = cast(Cache, RedisCache(redis_cache))
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": ["hello", "goodbye"], "testA": "world"}
test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]}
response = cache.get(test_request)
assert response is None
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == ["hello", "goodbye"]
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays
arr = np.random.rand(4, 4)
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
elif cache_type == "redis":
cache = RedisCache(redis_cache, request_type=DiffusionRequest)
elif cache_type == "postgres":
cache = _get_postgres_cache(request_type=DiffusionRequest)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays byte serializer
arr = np.random.rand(4, 4)
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
# Make sure to not hit the cache
new_request_diff = DiffusionRequest(**request_diff.dict())
new_request_diff.prompt = ["blahhh", "yayayay"]
response = Response(
response=model_choice_arr_int,
cached=False,
request=new_request_diff,
usages=None,
request_type=DiffusionRequest,
response_type="array",
)
if cache_type == "sqlite":
cache = SQLiteCache(
@ -228,15 +198,21 @@ def test_get_batch_prompt(
cache_args={"array_serializer": "byte_string"},
)
response = cache.get(test_request)
assert response is None
cached_response = cache.get(new_request_diff.dict())
assert cached_response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
cache.set(new_request_diff.dict(), response.to_dict(drop_request=True))
cached_response = cache.get(new_request_diff.dict())
assert np.allclose(
cached_response.get_response()[0],
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
)
assert np.allclose(
cached_response.get_response()[1],
cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array,
)
assert cached_response.is_cached()
assert cached_response.get_request_obj() == new_request_diff
def test_noop_cache() -> None:

@ -33,10 +33,13 @@ def test_get_request() -> None:
"prompt": "hello",
"num_results": 3,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 3,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3,
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
}
request_params = client.get_request("hello", {"n": 5})
@ -45,10 +48,13 @@ def test_get_request() -> None:
"prompt": "hello",
"num_results": 5,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}
request_params = client.get_request(["hello"] * 5, {"n": 1})
@ -57,8 +63,11 @@ def test_get_request() -> None:
"prompt": ["hello"] * 5,
"num_results": 1,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}

@ -90,8 +90,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token)
else:
@ -101,6 +101,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "This is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -116,8 +117,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token)
else:
@ -127,6 +128,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "This is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
"run_id": "34",
}
@ -143,8 +145,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token)
else:
@ -154,6 +156,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -169,8 +172,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(stop_token="ll")
else:
@ -180,6 +183,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -212,8 +216,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token, is_batch=True)
else:
@ -224,6 +228,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "This is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -235,8 +240,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token, is_batch=True)
else:
@ -247,6 +252,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -262,6 +268,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
{
"prompt": "New prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": n,
},
)
@ -272,8 +279,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
@ -287,8 +294,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(stop_token="ll", is_batch=True)
else:
@ -309,9 +316,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello"]
assert (
@ -319,6 +324,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
{
"prompt": "This is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": 1,
},
)
@ -330,9 +336,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello", "hello"]
assert (
@ -340,6 +344,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": 1,
},
)
@ -350,9 +355,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
assert result.is_cached()
@ -361,6 +364,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
{
"prompt": "New prompt",
"engine": "dummy",
"request_cls": "LMRequest",
"num_results": 1,
},
)
@ -371,9 +375,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
assert result.is_cached()
@ -384,9 +386,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(stop_token="ll", is_batch=True)
assert res == ["he", "he"]
@ -407,25 +407,43 @@ def test_score_run(sqlite_cache: str) -> None:
{
"prompt": "This is a prompt",
"engine": "dummy",
"request_cls": "LMScoreRequest",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
)
assert result == {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "text",
"item_dtype": None,
"response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
"response": {
"choices": [
{"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None}
]
},
"usages": {"usages": []},
"cached": False,
"request_params": {
"request": {
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
"engine": "text-ada-001",
"n": 1,
"client_timeout": 60,
"run_id": None,
"batch_size": 8,
"temperature": 0.7,
"max_tokens": 100,
"top_p": 1.0,
"top_k": 50,
"logprobs": None,
"stop_sequences": None,
"num_beams": 1,
"do_sample": False,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
},
"response_type": "text",
"request_type": "LMScoreRequest",
"item_dtype": None,
}
prompt_list = ["Hello is a prompt", "Hello is another prompt"]
@ -435,8 +453,8 @@ def test_score_run(sqlite_cache: str) -> None:
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"request_cls": "LMScoreRequest",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
@ -446,30 +464,48 @@ def test_score_run(sqlite_cache: str) -> None:
{
"prompt": "Hello is another prompt",
"engine": "dummy",
"request_cls": "LMScoreRequest",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
)
assert result == {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "text",
"item_dtype": None,
"response": {
"choices": [
{"text": "Hello is a prompt", "logprob": 0.3},
{"text": "Hello is another prompt", "logprob": 0.3},
{"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None},
{
"text": "Hello is another prompt",
"token_logprobs": [0.3],
"tokens": None,
},
]
},
"usages": {"usages": []},
"cached": False,
"request_params": {
"request": {
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
"engine": "text-ada-001",
"n": 1,
"client_timeout": 60,
"run_id": None,
"batch_size": 8,
"temperature": 0.7,
"max_tokens": 100,
"top_p": 1.0,
"top_k": 50,
"logprobs": None,
"stop_sequences": None,
"num_beams": 1,
"do_sample": False,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
},
"response_type": "text",
"request_type": "LMScoreRequest",
"item_dtype": None,
}
@ -644,8 +680,8 @@ def test_openai(sqlite_cache: str) -> None:
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.get_response() == res
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 15
assert response.get_usage_obj().usages
assert response.get_usage_obj().usages[0].total_tokens == 15
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
@ -662,12 +698,9 @@ def test_openai(sqlite_cache: str) -> None:
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 15
assert response.get_json_response()["usage"][1]["total_tokens"] == 16
assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
assert response.get_usage_obj().usages[0].total_tokens == 15
assert response.get_usage_obj().usages[1].total_tokens == 16
response = cast(
Response, client.run("Why are there bananas?", return_response=True)
@ -691,12 +724,9 @@ def test_openai(sqlite_cache: str) -> None:
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 17
assert response.get_json_response()["usage"][1]["total_tokens"] == 15
assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
assert response.get_usage_obj().usages[0].total_tokens == 17
assert response.get_usage_obj().usages[1].total_tokens == 15
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
@ -721,8 +751,8 @@ def test_openaichat(sqlite_cache: str) -> None:
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.get_response() == res
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 23
assert response.get_usage_obj().usages
assert response.get_usage_obj().usages[0].total_tokens == 23
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
@ -749,12 +779,9 @@ def test_openaichat(sqlite_cache: str) -> None:
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 25
assert response.get_json_response()["usage"][1]["total_tokens"] == 23
assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
assert response.get_usage_obj().usages[0].total_tokens == 25
assert response.get_usage_obj().usages[1].total_tokens == 23
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
@ -795,8 +822,8 @@ def test_openaiembedding(sqlite_cache: str) -> None:
assert isinstance(response.get_response(), np.ndarray)
assert np.allclose(response.get_response(), res)
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 5
assert response.get_usage_obj().usages
assert response.get_usage_obj().usages[0].total_tokens == 5
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
@ -817,12 +844,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 5
assert response.get_json_response()["usage"][1]["total_tokens"] == 6
assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
assert response.get_usage_obj().usages[0].total_tokens == 5
assert response.get_usage_obj().usages[1].total_tokens == 6
response = cast(
Response, client.run("Why are there bananas?", return_response=True)
@ -857,12 +881,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
and len(res_list) == 2
and isinstance(res_list[0], np.ndarray)
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 7
assert response.get_json_response()["usage"][1]["total_tokens"] == 5
assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
assert response.get_usage_obj().usages[0].total_tokens == 7
assert response.get_usage_obj().usages[1].total_tokens == 5
response = cast(
Response, client.run("Why are there oranges?", return_response=True)

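As context for the client-test changes above: usage bookkeeping is now read through pydantic accessors instead of indexing into the raw JSON. A minimal sketch, assuming an already-configured `Manifest` client (backend and cache arguments are omitted rather than guessed):

```python
# Minimal sketch of the accessor pattern the updated client tests exercise.
# The Manifest configuration is assumed; only the Response accessors shown
# here are taken from the tests above.
from typing import cast

from manifest import Manifest, Response


def show_usage(manifest: Manifest) -> None:
    response = cast(
        Response, manifest.run("Why are there apples?", return_response=True)
    )
    # Pydantic-backed accessors replace get_json_response()["usage"][i]["total_tokens"].
    usage = response.get_usage_obj()
    if usage.usages:
        print(usage.usages[0].total_tokens)
    print(response.is_cached())
    print(response.get_request_obj())
```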
@ -1,230 +1,301 @@
"""Response test."""
from typing import Any, Dict
from typing import List, cast
import numpy as np
import pytest
from manifest import Response
from manifest.request import LMRequest
from manifest.request import EmbeddingRequest, LMRequest
from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages
def test_init() -> None:
def test_init(
model_choice: ModelChoices,
model_choice_arr: ModelChoices,
model_choice_arr_int: ModelChoices,
request_lm: LMRequest,
request_array: EmbeddingRequest,
) -> None:
"""Test response initialization."""
with pytest.raises(ValueError) as exc_info:
response = Response(4, False, {}) # type: ignore
assert str(exc_info.value) == "Response must be dict. Response is\n4."
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict with a nonempty list of choices. "
"Response is\n{'test': 'hello'}."
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": [{"blah": "hello"}]}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict "
"with a list of choices with text field"
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": []}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict with a nonempty list of choices. "
"Response is\n{'choices': []}."
)
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response._response == {"choices": [{"text": "hello"}]}
response = Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert response._response == model_choice
assert response._cached is False
assert response._request_params == {}
assert response.item_dtype is None
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response._response == {"choices": [{"text": "hello"}]}
assert response._cached is True
assert response._request_params == {"request": "yoyo"}
assert response.item_dtype is None
assert response._request == request_lm
assert response._usages == Usages(usages=[])
assert response._request_type == LMRequest
assert response._response_type == "text"
assert response._item_dtype is None
response = Response(
{"generations": [{"txt": "hello"}], "logits": []},
False,
{},
generation_key="generations",
logits_key="logits",
item_key="txt",
)
assert response._response == {"generations": [{"txt": "hello"}], "logits": []}
response=model_choice_arr_int,
cached=False,
request=request_array,
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
request_type=EmbeddingRequest,
response_type="array",
)
assert response._cached is False
assert response._request_params == {}
assert response.item_dtype is None
assert response._request == request_array
assert sum([usg.total_tokens for usg in response._usages.usages]) == 10
assert response._request_type == EmbeddingRequest
assert response._response_type == "array"
assert response._item_dtype == "int64"
int_arr = np.random.randint(20, size=(4, 4))
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
with pytest.raises(ValueError) as excinfo:
Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="blah",
)
assert "blah" in str(excinfo.value)
# Can't convert array with text
with pytest.raises(ValueError) as excinfo:
Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="array",
)
assert str(excinfo.value) == (
"response_type is array but response is "
"<class 'manifest.response.LMModelChoice'>"
)
# Can't convert text with array
with pytest.raises(ValueError) as excinfo:
Response(
response=model_choice_arr,
cached=False,
request=request_array,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert str(excinfo.value) == (
"response_type is text but response is "
"<class 'manifest.response.ArrayModelChoice'>"
)
assert response._response == {"choices": [{"array": int_arr}]}
assert response._cached is True
assert response._request_params == {"request": "yoyo"}
assert response.item_dtype == "int64"
def test_getters() -> None:
def test_getters(model_choice: ModelChoices, request_lm: LMRequest) -> None:
"""Test response cached."""
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response.get_json_response() == {"choices": [{"text": "hello"}]}
response = Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert response.get_response_obj() == model_choice
assert response.is_cached() is False
assert response.get_request() == {}
assert response.get_request_obj() == request_lm
assert response.get_usage_obj() == Usages(usages=[])
assert response.get_json_response() == model_choice.dict()
assert response.get_response() == ["hello", "bye"]
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_json_response() == {"choices": [{"text": "hello"}]}
assert response.is_cached() is True
assert response.get_request() == {"request": "yoyo"}
int_arr = np.random.randint(20, size=(4, 4))
def test_serialize(
model_choice: ModelChoices,
model_choice_arr: ModelChoices,
model_choice_arr_int: ModelChoices,
request_lm: LMRequest,
request_array: EmbeddingRequest,
) -> None:
"""Test response serialization."""
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert response.get_json_response() == {"choices": [{"array": int_arr}]}
assert response.is_cached() is True
assert response.get_request() == {"request": "yoyo"}
deserialized_response = Response.deserialize(response.serialize())
assert deserialized_response.get_response_obj() == model_choice
assert deserialized_response.is_cached() is False
assert deserialized_response.get_request_obj() == request_lm
assert deserialized_response.get_usage_obj() == Usages(usages=[])
assert deserialized_response.get_json_response() == model_choice.dict()
assert deserialized_response.get_response() == ["hello", "bye"]
deserialized_response = Response.from_dict(response.to_dict())
assert deserialized_response.get_response_obj() == model_choice
assert deserialized_response.is_cached() is False
assert deserialized_response.get_request_obj() == request_lm
assert deserialized_response.get_usage_obj() == Usages(usages=[])
assert deserialized_response.get_json_response() == model_choice.dict()
assert deserialized_response.get_response() == ["hello", "bye"]
def test_serialize() -> None:
"""Test response serialization."""
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
deserialized_response = Response.deserialize(response.serialize())
assert deserialized_response._response == {"choices": [{"text": "hello"}]}
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
deserialized_response = Response.from_dict(
response.to_dict(drop_request=True), request_dict={"prompt": "blahhhh"}
)
assert deserialized_response.get_response_obj() == model_choice
assert deserialized_response.is_cached() is False
assert deserialized_response.get_request_obj().prompt == "blahhhh"
assert deserialized_response.get_usage_obj() == Usages(usages=[])
assert deserialized_response.get_json_response() == model_choice.dict()
assert deserialized_response.get_response() == ["hello", "bye"]
int_arr = np.random.randint(20, size=(4, 4))
# Int type
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
response=model_choice_arr_int,
cached=False,
request=request_array,
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
request_type=EmbeddingRequest,
response_type="array",
)
deserialized_response = Response.deserialize(response.serialize())
assert deserialized_response._item_dtype == "int64"
assert (
cast(
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
).array.dtype
== np.int64
)
assert np.array_equal(
deserialized_response._response["choices"][0]["array"], int_arr
cast(
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
).array,
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
)
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
float_arr = np.random.randn(4, 4)
# Float type
response = Response(
{"choices": [{"array": float_arr}]}, True, {"request": "yoyo"}, item_key="array"
response=model_choice_arr,
cached=False,
request=request_array,
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
request_type=EmbeddingRequest,
response_type="array",
)
deserialized_response = Response.deserialize(response.serialize())
assert deserialized_response._item_dtype == "float64"
assert (
cast(
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
).array.dtype
== np.float64
)
assert np.array_equal(
deserialized_response._response["choices"][0]["array"], float_arr
cast(
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
).array,
cast(ArrayModelChoice, model_choice_arr.choices[0]).array,
)
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
def test_get_results() -> None:
def test_get_results(
model_choice: ModelChoices,
model_choice_single: ModelChoices,
model_choice_arr: ModelChoices,
request_lm: LMRequest,
request_array: EmbeddingRequest,
) -> None:
"""Test response get results."""
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_response() == "hello"
response = Response(
response=model_choice_single,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert response.get_response() == "helloo"
assert response.get_response(stop_token="ll") == "he"
assert response.get_response(stop_token="ll", is_batch=True) == ["he"]
response = Response(
{"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
True,
{"request": "yoyo"},
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
assert response.get_response() == ["hello", "my", "name"]
assert response.get_response(stop_token="m") == ["hello", "", "na"]
assert response.get_response(stop_token="m", is_batch=True) == ["hello", "", "na"]
assert response.get_response() == ["hello", "bye"]
assert response.get_response(stop_token="b") == ["hello", ""]
assert response.get_response(stop_token="y", is_batch=True) == ["hello", "b"]
float_arr = np.random.randn(4, 4)
float_arr1 = cast(ArrayModelChoice, model_choice_arr.choices[0]).array
float_arr2 = cast(ArrayModelChoice, model_choice_arr.choices[1]).array
response = Response(
{"choices": [{"array": float_arr}, {"array": float_arr}]},
True,
{"request": "yoyo"},
item_key="array",
response=model_choice_arr,
cached=False,
request=request_array,
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
request_type=EmbeddingRequest,
response_type="array",
)
assert response.get_response() == [float_arr, float_arr]
assert response.get_response(stop_token="m") == [float_arr, float_arr]
assert np.array_equal(response.get_response()[0], float_arr1)
assert np.array_equal(response.get_response()[1], float_arr2)
assert np.array_equal(response.get_response(stop_token="t")[0], float_arr1)
assert np.array_equal(response.get_response(stop_token="t")[1], float_arr2)
def test_union_all() -> None:
def test_union_all(
model_choice: ModelChoices,
model_choice_single: ModelChoices,
request_lm: LMRequest,
request_lm_single: LMRequest,
) -> None:
"""Test union all."""
request_paramsa = LMRequest(prompt=["apple", "orange", "pear"]).to_dict()
request_paramsa["model"] = "modelA"
response_paramsa = {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
]
}
responsea = Response(response_paramsa, False, request_paramsa)
response1 = Response(
response=model_choice,
cached=False,
request=request_lm,
usages=None,
request_type=LMRequest,
response_type="text",
)
request_paramsb = LMRequest(prompt=["banana", "pineapple", "mango"]).to_dict()
request_paramsb["model"] = "modelB"
response_paramsb = {
"choices": [
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
]
}
responseb = Response(response_paramsb, False, request_paramsb)
response2 = Response(
response=model_choice_single,
cached=False,
request=request_lm_single,
usages=None,
request_type=LMRequest,
response_type="text",
)
final_response = Response.union_all([responsea, responseb])
final_response = Response.union_all([response1, response2])
assert final_response.get_json_response() == {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
{"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None},
{"text": "bye", "token_logprobs": [0.3], "tokens": None},
{"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None},
]
}
final_request = LMRequest(
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
).to_dict()
final_request["model"] = "modelA"
assert final_response.get_request() == final_request
assert not final_response.is_cached()
assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
merged_prompts: List[str] = request_lm.prompt + [request_lm_single.prompt] # type: ignore # noqa: E501
assert final_response.get_request_obj().prompt == merged_prompts
assert final_response.get_request_obj().engine == "dummy::text-ada-001"
# Modify A to have usage and cached
response_paramsa_2: Dict[str, Any] = {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
],
"usage": [
{"completion_tokens": 10},
{"completion_tokens": 10},
{"completion_tokens": 10},
],
}
responsea = Response(response_paramsa_2, True, request_paramsa)
response1 = Response(
response=model_choice,
cached=False,
request=request_lm,
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
request_type=LMRequest,
response_type="text",
)
final_response = Response.union_all([responsea, responseb])
assert final_response.get_json_response() == {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
],
"usage": [
{"completion_tokens": 10},
{"completion_tokens": 10},
{"completion_tokens": 10},
{},
{},
{},
],
}
final_request = LMRequest(
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
).to_dict()
final_request["model"] = "modelA"
assert final_response.get_request() == final_request
assert final_response.is_cached()
final_response = Response.union_all([response1, response2])
assert final_response.get_usage_obj() == Usages(
usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()]
)

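For readers following the new `test_init`/`test_serialize` fixtures, here is a hedged sketch of constructing and round-tripping a `Response` with the pydantic types; the keyword fields passed to `LMModelChoice`, `ModelChoices`, and `LMRequest` are inferred from the expected dicts above and should be read as assumptions rather than a definitive API:

```python
# Hedged sketch: build a pydantic-typed Response and round-trip it, mirroring
# test_init/test_serialize above. Field names passed to LMModelChoice,
# ModelChoices, and LMRequest are inferred from the test expectations (assumed).
from manifest import Response
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Usage, Usages

choices = ModelChoices(
    choices=[LMModelChoice(text="hello", token_logprobs=[0.1, 0.2])]
)
request = LMRequest(prompt="a prompt", engine="dummy")  # assumed field names

response = Response(
    response=choices,
    cached=False,
    request=request,
    usages=Usages(usages=[Usage(total_tokens=4)]),
    request_type=LMRequest,
    response_type="text",
)

# to_dict()/from_dict() carry the request along; serialize()/deserialize()
# are exercised the same way in the tests above.
restored = Response.from_dict(response.to_dict())
print(restored.get_response())                           # "hello"
print(restored.get_usage_obj().usages[0].total_tokens)   # 4
print(restored.is_cached())                              # False
```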
@ -10,23 +10,23 @@ def test_response_to_key_array() -> None:
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)
res = {"choices": [{"array": arr}]}
res = {"response": {"choices": [{"array": arr}]}}
key = serializer.response_to_key(res)
key_dct = json.loads(key)
assert isinstance(key_dct["choices"][0]["array"], str)
assert isinstance(key_dct["response"]["choices"][0]["array"], str)
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])
assert np.allclose(arr, res2["response"]["choices"][0]["array"])
def test_response_to_key_numpybytes() -> None:
"""Test array serializer initialization."""
serializer = NumpyByteSerializer()
arr = np.random.rand(4, 4)
res = {"choices": [{"array": arr}]}
res = {"response": {"choices": [{"array": arr}]}}
key = serializer.response_to_key(res)
key_dct = json.loads(key)
assert isinstance(key_dct["choices"][0]["array"], str)
assert isinstance(key_dct["response"]["choices"][0]["array"], str)
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])
assert np.allclose(arr, res2["response"]["choices"][0]["array"])

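The serializer tests above reflect the other side of the refactor: the cached payload now nests choices under a top-level `"response"` key before array fields are encoded. A small round-trip sketch using the same calls as the tests:

```python
# Round-trip sketch for the new nested payload shape used by the serializers.
import json

import numpy as np

from manifest.caches.serializers import ArraySerializer

serializer = ArraySerializer()
arr = np.random.rand(4, 4)

# Arrays now sit under payload["response"]["choices"][i]["array"].
payload = {"response": {"choices": [{"array": arr}]}}

key = serializer.response_to_key(payload)
# The key is plain JSON; the array field is replaced by a string while serialized.
print(isinstance(json.loads(key)["response"]["choices"][0]["array"], str))  # True

restored = serializer.key_to_response(key)
print(np.allclose(arr, restored["response"]["choices"][0]["array"]))  # True
```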
@ -49,7 +49,7 @@ def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
return {
"response": response.get_response(),
"cached": response.is_cached(),
"request_params": response.get_request(),
"request_params": response.get_request_obj(),
}

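One downstream note on the web-app change above: `get_request_obj()` now hands the caller a pydantic `Request` rather than a plain dict. If a consumer needs a JSON-safe dict (an assumption about the caller, not a requirement of this endpoint), pydantic's `.dict()` is one way to get it:

```python
# Hedged helper: convert the pydantic Request returned by get_request_obj()
# into a plain dict. Whether the consumer needs this is an assumption.
from typing import Any, Dict

from manifest import Response


def request_params_as_dict(response: Response) -> Dict[str, Any]:
    return response.get_request_obj().dict()
```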