From d7401c6ec5756435f07395b6306ca97db50fca9e Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Mon, 24 Apr 2023 10:10:47 -0700 Subject: [PATCH] fix: added pydantic types to response (#84) --- CHANGELOG.rst | 1 + README.md | 4 +- examples/manifest_connection_pool.ipynb | 20 +- manifest/caches/cache.py | 11 +- manifest/caches/serializers.py | 28 +- manifest/clients/ai21.py | 2 +- manifest/clients/client.py | 62 +++- manifest/clients/cohere.py | 2 +- manifest/clients/diffuser.py | 2 +- manifest/clients/dummy.py | 57 ++- manifest/clients/huggingface.py | 17 +- manifest/clients/huggingface_embedding.py | 2 +- manifest/clients/openai.py | 1 + manifest/clients/openai_embedding.py | 2 +- manifest/clients/toma.py | 2 +- manifest/clients/toma_diffuser.py | 2 +- manifest/manifest.py | 94 +++-- manifest/request.py | 18 +- manifest/response.py | 313 ++++++++-------- tests/conftest.py | 85 +++++ tests/test_cache.py | 206 +++++------ tests/test_client.py | 21 +- tests/test_manifest.py | 199 ++++++----- tests/test_response.py | 411 +++++++++++++--------- tests/test_serializer.py | 12 +- web_app/main.py | 2 +- 26 files changed, 916 insertions(+), 660 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ba620a0..92fa08a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ Added Fixed ^^^^^ * Determine cache and response by request type, not client name +* Refactor Response to use Pydantic types for Request and Response 0.1.1 --------------------- diff --git a/README.md b/README.md index 6332b7e..f6e9a9d 100644 --- a/README.md +++ b/README.md @@ -126,9 +126,9 @@ results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the If something doesn't go right, you can also ask to get a raw manifest Response. ```python result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True) -print(result_object.get_request()) +print(result_object.get_request_obj()) print(result_object.is_cached()) -print(result_object.get_json_response()) +print(result_object.get_response_obj()) ``` By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run`. 
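The README hunk above swaps the raw-dict accessors for typed ones: `get_request_obj()` now returns a pydantic `Request` and `get_response_obj()` returns a `ModelChoices` object, with `get_usage_obj()` exposing token usage. A minimal sketch of reading these typed objects, assuming an OpenAI-backed `Manifest` configured as earlier in the README:

```python
from manifest import Manifest

manifest = Manifest(client_name="openai")

result_object = manifest.run(
    ["Where are the cats?", "Where are the dogs?"], return_response=True
)

# Typed pydantic objects instead of raw dicts
request = result_object.get_request_obj()   # a manifest.request.Request subclass
choices = result_object.get_response_obj()  # a manifest.response.ModelChoices
usages = result_object.get_usage_obj()      # a manifest.response.Usages

print(request.prompt)
print([choice.text for choice in choices.choices])
print([usage.total_tokens for usage in usages.usages])
```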
diff --git a/examples/manifest_connection_pool.ipynb b/examples/manifest_connection_pool.ipynb index 5328dd2..5b2b861 100644 --- a/examples/manifest_connection_pool.ipynb +++ b/examples/manifest_connection_pool.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -131,23 +131,23 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "For loop: 229.93\n", + "For loop: 128.68\n", "Running with async single client\n", "Running 1 tasks across all clients.\n", - "Async loop: 1.39\n", + "Async loop: 4.02\n", "Running with async two clients but not chunking\n", "Running 1 tasks across all clients.\n", - "Async loop: 1.65\n", + "Async loop: 3.92\n", "Running with async two clients and chunk size\n", "Running 20 tasks across all clients.\n", - "Async loop: 0.64\n" + "Async loop: 1.44\n" ] } ], diff --git a/manifest/caches/cache.py b/manifest/caches/cache.py index d0cd08f..e4cbe6d 100644 --- a/manifest/caches/cache.py +++ b/manifest/caches/cache.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Type, Union from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request -from manifest.response import RESPONSE_CONSTRUCTORS, Response +from manifest.response import Response # Non-text return type caches ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest} @@ -119,14 +119,9 @@ class Cache(ABC): key = self.serializer.request_to_key(request) cached_response = self.get_key(key) if cached_response: - cached = True response = self.serializer.key_to_response(cached_response) - return Response( - response, - cached, - request, - **RESPONSE_CONSTRUCTORS.get(self.request_type, {}), - ) + response["cached"] = True + return Response.from_dict(response, request_dict=request) return None def set(self, request: Dict, response: Dict) -> None: diff --git a/manifest/caches/serializers.py b/manifest/caches/serializers.py index d6ad506..a1ec3dd 100644 --- a/manifest/caches/serializers.py +++ b/manifest/caches/serializers.py @@ -77,14 +77,15 @@ class NumpyByteSerializer(Serializer): Returns: normalized key. """ + sub_response = response["response"] # Assume response is a dict with keys "choices" -> List dicts # with keys "array". 
- choices = response["choices"] + choices = sub_response["choices"] # We don't want to modify the response in place # but we want to avoid calling deepcopy on an array - del response["choices"] - response_copy = response.copy() - response["choices"] = choices + del sub_response["choices"] + response_copy = sub_response.copy() + sub_response["choices"] = choices response_copy["choices"] = [] for choice in choices: if "array" not in choice: @@ -101,7 +102,8 @@ class NumpyByteSerializer(Serializer): hash_str = f.getvalue().hex() new_choice["array"] = hash_str response_copy["choices"].append(new_choice) - return json.dumps(response_copy, sort_keys=True) + response["response"] = response_copy + return json.dumps(response, sort_keys=True) def key_to_response(self, key: str) -> Dict: """ @@ -114,7 +116,7 @@ class NumpyByteSerializer(Serializer): unnormalized response dict. """ response = json.loads(key) - for choice in response["choices"]: + for choice in response["response"]["choices"]: hash_str = choice["array"] byte_str = bytes.fromhex(hash_str) with io.BytesIO(byte_str) as f: @@ -152,14 +154,15 @@ class ArraySerializer(Serializer): Returns: normalized key. """ + sub_response = response["response"] # Assume response is a dict with keys "choices" -> List dicts # with keys "array". - choices = response["choices"] + choices = sub_response["choices"] # We don't want to modify the response in place # but we want to avoid calling deepcopy on an array - del response["choices"] - response_copy = response.copy() - response["choices"] = choices + del sub_response["choices"] + response_copy = sub_response.copy() + sub_response["choices"] = choices response_copy["choices"] = [] for choice in choices: if "array" not in choice: @@ -179,7 +182,8 @@ class ArraySerializer(Serializer): response_copy["choices"].append(new_choice) if not self.writer.contains_key(hash_str): self.writer.put(hash_str, arr) - return json.dumps(response_copy, sort_keys=True) + response["response"] = response_copy + return json.dumps(response, sort_keys=True) def key_to_response(self, key: str) -> Dict: """ @@ -194,7 +198,7 @@ class ArraySerializer(Serializer): unnormalized response dict. """ response = json.loads(key) - for choice in response["choices"]: + for choice in response["response"]["choices"]: hash_str = choice["array"] choice["array"] = self.writer.get(hash_str) return response diff --git a/manifest/clients/ai21.py b/manifest/clients/ai21.py index 6bd356c..c81c8ba 100644 --- a/manifest/clients/ai21.py +++ b/manifest/clients/ai21.py @@ -94,7 +94,7 @@ class AI21Client(Client): """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. 
diff --git a/manifest/clients/client.py b/manifest/clients/client.py index 2b3df99..154ed00 100644 --- a/manifest/clients/client.py +++ b/manifest/clients/client.py @@ -10,8 +10,21 @@ import aiohttp import requests from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential -from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request -from manifest.response import RESPONSE_CONSTRUCTORS, Response +from manifest.request import ( + DEFAULT_REQUEST_KEYS, + NOT_CACHE_KEYS, + LMScoreRequest, + Request, +) +from manifest.response import ( + RESPONSE_CONSTRUCTORS, + ArrayModelChoice, + LMModelChoice, + ModelChoices, + Response, + Usage, + Usages, +) logger = logging.getLogger(__name__) @@ -161,16 +174,30 @@ class Client(ABC): request_params = self.get_request_params(request) for key in NOT_CACHE_KEYS: request_params.pop(key, None) + # Make sure to add model params and request class request_params.update(self.get_model_params()) + request_params["request_cls"] = request.__class__.__name__ return request_params def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]: """Split usage into list of usages for each prompt.""" return [] - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def get_model_choices(self, response: Dict) -> ModelChoices: + """Format response to ModelChoices.""" + # Array or text response + response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"] + if response_type == "array": + choices: List[Union[LMModelChoice, ArrayModelChoice]] = [ + ArrayModelChoice(**choice) for choice in response["choices"] + ] + else: + choices = [LMModelChoice(**choice) for choice in response["choices"]] + return ModelChoices(choices=choices) + + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ - Format response to dict. + Validate response as dict. 
Args: response: response @@ -246,7 +273,7 @@ class Client(ABC): except requests.exceptions.HTTPError: logger.error(res.json()) raise requests.exceptions.HTTPError(res.json()) - return self.format_response(res.json(), request_params) + return self.validate_response(res.json(), request_params) @retry( reraise=True, @@ -277,7 +304,7 @@ class Client(ABC): ) as res: res.raise_for_status() res_json = await res.json(content_type=None) - return self.format_response(res_json, request_params) + return self.validate_response(res_json, request_params) def run_request(self, request: Request) -> Response: """ @@ -301,11 +328,16 @@ class Client(ABC): for key in DEFAULT_REQUEST_KEYS: request_params.pop(key, None) response_dict = self._run_completion(request_params, retry_timeout) + usages = None + if "usage" in response_dict: + usages = [Usage(**usage) for usage in response_dict["usage"]] + return Response( - response_dict, + response=self.get_model_choices(response_dict), cached=False, - request_params=request_params, - **RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore + request=request, + usages=Usages(usages=usages) if usages else None, + **RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore ) async def arun_batch_request(self, request: Request) -> Response: @@ -353,18 +385,20 @@ class Client(ABC): if "usage" in res_dict: usages.extend(res_dict["usage"]) final_response_dict = {"choices": choices} + final_usages = None if usages: - final_response_dict["usage"] = usages + final_usages = Usages(usages=[Usage(**usage) for usage in usages]) return Response( - final_response_dict, + self.get_model_choices(final_response_dict), cached=False, - request_params=request_params, - **RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore + request=request, + usages=final_usages, + **RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore ) def get_score_prompt_request( self, - request: Request, + request: LMScoreRequest, ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. diff --git a/manifest/clients/cohere.py b/manifest/clients/cohere.py index 62b3d89..7697958 100644 --- a/manifest/clients/cohere.py +++ b/manifest/clients/cohere.py @@ -93,7 +93,7 @@ class CohereClient(Client): """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. diff --git a/manifest/clients/diffuser.py b/manifest/clients/diffuser.py index e0254d2..9f8364c 100644 --- a/manifest/clients/diffuser.py +++ b/manifest/clients/diffuser.py @@ -86,7 +86,7 @@ class DiffuserClient(Client): res["client_name"] = self.NAME return res - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. 
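The new `Client.get_model_choices` above keys off `RESPONSE_CONSTRUCTORS` to decide whether raw `choices` dicts become `LMModelChoice` or `ArrayModelChoice` entries. A rough standalone illustration of that dispatch, assuming the pydantic models introduced in `manifest/response.py` on this branch (the `to_model_choices` helper is hypothetical):

```python
import numpy as np

from manifest.request import DiffusionRequest, LMRequest
from manifest.response import (
    RESPONSE_CONSTRUCTORS,
    ArrayModelChoice,
    LMModelChoice,
    ModelChoices,
)


def to_model_choices(request_cls, raw_choices):
    """Mirror Client.get_model_choices: pick the choice model by response type."""
    response_type = RESPONSE_CONSTRUCTORS[request_cls]["response_type"]
    if response_type == "array":
        choices = [ArrayModelChoice(**choice) for choice in raw_choices]
    else:
        choices = [LMModelChoice(**choice) for choice in raw_choices]
    return ModelChoices(choices=choices)


# Text-style choices (e.g. an LM client)
text_choices = to_model_choices(LMRequest, [{"text": "hello", "token_logprobs": [0.1]}])

# Array-style choices (e.g. a diffusion client)
array_choices = to_model_choices(DiffusionRequest, [{"array": np.zeros((2, 2))}])
```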
diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py index 8b74712..e195401 100644 --- a/manifest/clients/dummy.py +++ b/manifest/clients/dummy.py @@ -3,8 +3,8 @@ import logging from typing import Any, Dict, Optional from manifest.clients.client import Client -from manifest.request import LMRequest, Request -from manifest.response import Response +from manifest.request import LMRequest, LMScoreRequest, Request +from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages logger = logging.getLogger(__name__) @@ -86,15 +86,30 @@ class DummyClient(Client): num_results = 1 request_params = request.to_dict(self.PARAMS) - response_dict = { - "choices": [{"text": "hello"}] - * int(request_params["num_results"]) - * num_results, - "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] - * int(request_params["num_results"]) - * num_results, - } - return Response(response_dict, False, request_params) + return Response( + response=ModelChoices( + choices=[LMModelChoice(text="hello")] # type: ignore + * int(request_params["num_results"]) + * num_results + ), + cached=False, + request=request, + usages=Usages( + usages=[ + Usage( + **{ + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + } + ) + ] + * int(request_params["num_results"]) + * num_results + ), + response_type="text", + request_type=self.REQUEST_CLS, + ) async def arun_batch_request(self, request: Request) -> Response: """ @@ -110,7 +125,7 @@ class DummyClient(Client): def get_score_prompt_request( self, - request: Request, + request: LMScoreRequest, ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. @@ -126,17 +141,27 @@ class DummyClient(Client): num_results = len(request.prompt) else: num_results = 1 - request_params = {"prompt": request.prompt} - response_dict = { "choices": [ { "text": request.prompt if isinstance(request.prompt, str) else request.prompt[i], - "logprob": 0.3, + "token_logprobs": [0.3], } for i in range(num_results) ] } - return Response(response_dict, False, request_params) + return Response( + response=ModelChoices( + choices=[ + LMModelChoice(**choice) # type: ignore + for choice in response_dict["choices"] + ] + ), + cached=False, + request=request, + usages=None, + response_type="text", + request_type=LMScoreRequest, + ) diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py index 4e1d95d..5c02afc 100644 --- a/manifest/clients/huggingface.py +++ b/manifest/clients/huggingface.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Optional import requests from manifest.clients.client import Client -from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request -from manifest.response import Response +from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, LMScoreRequest +from manifest.response import LMModelChoice, ModelChoices, Response logger = logging.getLogger(__name__) @@ -82,7 +82,7 @@ class HuggingFaceClient(Client): def get_score_prompt_request( self, - request: Request, + request: LMScoreRequest, ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. 
@@ -116,4 +116,13 @@ class HuggingFaceClient(Client): logger.error(res.text) raise e response_dict = res.json() - return Response(response_dict, cached=False, request_params=request_params) + return Response( + response=ModelChoices( + choices=[LMModelChoice(**choice) for choice in response_dict["choices"]] + ), + cached=False, + request=request, + usages=None, + response_type="text", + request_type=LMScoreRequest, + ) diff --git a/manifest/clients/huggingface_embedding.py b/manifest/clients/huggingface_embedding.py index a052b85..8a14387 100644 --- a/manifest/clients/huggingface_embedding.py +++ b/manifest/clients/huggingface_embedding.py @@ -72,7 +72,7 @@ class HuggingFaceEmbeddingClient(Client): res["client_name"] = self.NAME return res - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index 1a59369..3855580 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -37,6 +37,7 @@ class OpenAIClient(Client): "n": ("n", 1), "top_p": ("top_p", 1.0), "top_k": ("best_of", 1), + "logprobs": ("logprobs", None), "stop_sequences": ("stop", None), # OpenAI doesn't like empty lists "presence_penalty": ("presence_penalty", 0.0), "frequency_penalty": ("frequency_penalty", 0.0), diff --git a/manifest/clients/openai_embedding.py b/manifest/clients/openai_embedding.py index ff41b5a..a5984da 100644 --- a/manifest/clients/openai_embedding.py +++ b/manifest/clients/openai_embedding.py @@ -76,7 +76,7 @@ class OpenAIEmbeddingClient(OpenAIClient): """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. diff --git a/manifest/clients/toma.py b/manifest/clients/toma.py index 259a310..69c8b8b 100644 --- a/manifest/clients/toma.py +++ b/manifest/clients/toma.py @@ -143,7 +143,7 @@ class TOMAClient(Client): } return heartbeats - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. diff --git a/manifest/clients/toma_diffuser.py b/manifest/clients/toma_diffuser.py index c38c46a..1ea6ad0 100644 --- a/manifest/clients/toma_diffuser.py +++ b/manifest/clients/toma_diffuser.py @@ -46,7 +46,7 @@ class TOMADiffuserClient(TOMAClient): """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} - def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. 
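With the client diffs in place, the cache path (see `manifest/caches/cache.py` above and `manifest/manifest.py` below) stores `Response.to_dict(drop_request=True)` and rebuilds the object with `Response.from_dict(..., request_dict=...)`. A minimal round-trip sketch of that flow, using a text-style response from this branch:

```python
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages

request = LMRequest(prompt="hello", engine="dummy")
response = Response(
    response=ModelChoices(choices=[LMModelChoice(text="hi", token_logprobs=[0.3])]),
    cached=False,
    request=request,
    usages=Usages(usages=[Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2)]),
    response_type="text",
    request_type=LMRequest,
)

# What the cache stores: the response payload without the request
cached_payload = response.to_dict(drop_request=True)

# What the cache hands back: rebuild against the original request dict
rebuilt = Response.from_dict(cached_payload, request_dict=request.dict())
assert rebuilt.get_response() == "hi"
```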
diff --git a/manifest/manifest.py b/manifest/manifest.py index 14db3c9..ff3b319 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -2,7 +2,7 @@ import asyncio import copy import logging -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import numpy as np @@ -17,8 +17,8 @@ from manifest.connections.client_pool import ( ClientConnection, ClientConnectionPool, ) -from manifest.request import Request -from manifest.response import Response +from manifest.request import LMScoreRequest, Request +from manifest.response import ModelChoices, Response, Usage, Usages logging.getLogger("openai").setLevel(logging.WARNING) logger = logging.getLogger(__name__) @@ -178,82 +178,72 @@ class Manifest: number_prompts = len(cached_idx_to_response) single_output = False if response: - if isinstance(response.get_request()["prompt"], str): + if isinstance(response.get_request_obj().prompt, str): single_output = True number_prompts += 1 else: - number_prompts += len(response.get_request()["prompt"]) - response_gen_key = None - response_logits_key = None - response_item_key = None + number_prompts += len(response.get_request_obj().prompt) + response_type = None + request_type: Type[Request] = None for idx in range(number_prompts): if idx in cached_idx_to_response: cached_res = cached_idx_to_response[idx] - response_gen_key = cached_res.generation_key - response_logits_key = cached_res.logits_key - response_item_key = cached_res.item_key - response_usage_key = cached_res.usage_key - all_input_prompts.append(cached_res.get_request()["prompt"]) - json_response = cached_res.get_json_response() + response_type = cached_res._response_type + request_type = cached_res._request_type + all_input_prompts.append(cached_res.get_request_obj().prompt) if request.n == 1: assert ( - len(json_response[response_gen_key]) == 1 + len(cached_res.get_response_obj().choices) == 1 ), "cached response should have only one choice" - all_model_choices.extend(json_response[response_gen_key]) - if response_usage_key: - all_usages.extend(json_response[response_usage_key]) + all_model_choices.extend(cached_res.get_response_obj().choices) + if cached_res.get_usage_obj().usages: + all_usages.extend(cached_res.get_usage_obj().usages) else: assert response is not None, "response should not be None" response = cast(Response, response) - response_gen_key = response.generation_key - response_logits_key = response.logits_key - response_item_key = response.item_key - response_usage_key = response.usage_key + response_type = response._response_type + request_type = response._request_type # the choices list in the response is a flat one. 
# length is request.n * num_prompts - current_choices = response.get_json_response()[response_gen_key][ + current_choices = response.get_response_obj().choices[ response_idx * request.n : (response_idx + 1) * request.n ] all_model_choices.extend(current_choices) - if isinstance(response.get_request()["prompt"], list): - prompt = response.get_request()["prompt"][response_idx] + if isinstance(response.get_request_obj().prompt, list): + prompt = response.get_request_obj().prompt[response_idx] else: - prompt = str(response.get_request()["prompt"]) - if response_usage_key: - usage = response.get_json_response()[response_usage_key][ + prompt = str(response.get_request_obj().prompt) + usages: Optional[List[Usage]] = None + if response.get_usage_obj().usages: + usages = response.get_usage_obj().usages[ response_idx * request.n : (response_idx + 1) * request.n ] - all_usages.extend(usage) + all_usages.extend(usages) all_input_prompts.append(prompt) # set cache new_request = copy.deepcopy(request) new_request.prompt = prompt cache_key = client.get_cache_key(new_request) - new_response_key = copy.deepcopy(response.get_json_response()) - new_response_key[response_gen_key] = current_choices - if response_usage_key: - new_response_key[response_usage_key] = usage - self.cache.set(cache_key, new_response_key) + new_response = copy.deepcopy(response) + new_response._response.choices = current_choices + new_response._usages = Usages(usages=(usages or [])) + self.cache.set(cache_key, new_response.to_dict(drop_request=True)) response_idx += 1 new_request = copy.deepcopy(request) new_request.prompt = ( - all_input_prompts + all_input_prompts # type: ignore if len(all_input_prompts) > 1 or not single_output else all_input_prompts[0] ) - new_response = {response_gen_key: all_model_choices} - if response_usage_key: - new_response[response_usage_key] = all_usages response_obj = Response( - new_response, + response=ModelChoices(choices=all_model_choices), cached=len(cached_idx_to_response) > 0, - request_params=client.get_cache_key(new_request), - generation_key=response_gen_key, - logits_key=response_logits_key, - item_key=response_item_key, - usage_key=response_usage_key, + request=new_request, + usages=Usages(usages=all_usages), + response_type=response_type, + request_type=request_type, ) return response_obj @@ -457,20 +447,20 @@ class Manifest: client = self.client_pool.get_client() # Must pass kwargs as dict for client "pop" methods removed used arguments request_params = client.get_request(prompt, kwargs) - request_params.request_type = "score_prompt" + request_params_as_score = LMScoreRequest(**request_params.to_dict()) - if request_params.n > 1: + if request_params_as_score.n > 1: raise ValueError("Sequence scoring does not support n > 1.") - self._validate_kwargs(kwargs, request_params) + self._validate_kwargs(kwargs, request_params_as_score) - cached_idx_to_response, request_params = self._split_cached_requests( - request_params, client, overwrite_cache + cached_idx_to_response, request_params_as_score = self._split_cached_requests( # type: ignore # noqa: E501 + request_params_as_score, client, overwrite_cache ) # If not None value or empty list - run new request - if request_params.prompt: + if request_params_as_score.prompt: try: response = cast(HuggingFaceClient, client).get_score_prompt_request( - request_params + request_params_as_score ) except AttributeError: raise ValueError("`score_prompt` only supported for HF models.") @@ -479,7 +469,7 @@ class Manifest: response = None final_response = 
self._stitch_responses_and_cache( - request=request_params, + request=request_params_as_score, client=client, response=response, cached_idx_to_response=cached_idx_to_response, diff --git a/manifest/request.py b/manifest/request.py index 242f308..097732e 100644 --- a/manifest/request.py +++ b/manifest/request.py @@ -3,13 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel +# Used when unioning requests after async connection pool +ENGINE_SEP = "::" NOT_CACHE_KEYS = {"client_timeout", "batch_size"} # The below should match those in Request. DEFAULT_REQUEST_KEYS = { "client_timeout": ("client_timeout", 60), # seconds "batch_size": ("batch_size", 8), "run_id": ("run_id", None), - "request_type": ("request_type", None), } @@ -34,9 +35,6 @@ class Request(BaseModel): # Batch size for async batch run batch_size: int = 8 - # Request type None is for completion. Used for scoring prompt - request_type: str = None - def to_dict( self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True ) -> Dict[str, Any]: @@ -78,6 +76,9 @@ class LMRequest(Request): # Top k sampling taking top_k highest probability tokens top_k: int = 50 + # Logprobs return value + logprobs: Optional[int] = None + # Stop sequences stop_sequences: Optional[List[str]] = None @@ -100,6 +101,12 @@ class LMRequest(Request): frequency_penalty: float = 0 +class LMScoreRequest(LMRequest): + """Language Model Score Request object.""" + + pass + + class EmbeddingRequest(Request): """Embedding Request object.""" @@ -109,9 +116,6 @@ class EmbeddingRequest(Request): class DiffusionRequest(Request): """Diffusion Model Request object.""" - # Request type - request_type: str = "diffusion" - # Number of steps num_inference_steps: int = 50 diff --git a/manifest/response.py b/manifest/response.py index f149b79..9a35898 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -1,21 +1,25 @@ """Client response.""" import copy import json -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Type, Union, cast import numpy as np - -from manifest.request import DiffusionRequest, EmbeddingRequest - -RESPONSE_CONSTRUCTORS = { - EmbeddingRequest: { - "logits_key": "token_logprobs", - "item_key": "array", - }, - DiffusionRequest: { - "logits_key": "token_logprobs", - "item_key": "array", - }, +from pydantic import BaseModel + +from manifest.request import ( + ENGINE_SEP, + DiffusionRequest, + EmbeddingRequest, + LMRequest, + LMScoreRequest, + Request, +) + +RESPONSE_CONSTRUCTORS: Dict[Type[Request], Dict[str, Union[str, Type[Request]]]] = { + LMRequest: {"response_type": "text", "request_type": LMRequest}, + LMScoreRequest: {"response_type": "text", "request_type": LMScoreRequest}, + EmbeddingRequest: {"response_type": "array", "request_type": EmbeddingRequest}, + DiffusionRequest: {"response_type": "array", "request_type": DiffusionRequest}, } @@ -29,94 +33,114 @@ class NumpyArrayEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, obj) +class Usage(BaseModel): + """Prompt usage class.""" + + completion_tokens: int = 0 + prompt_tokens: int = 0 + total_tokens: int = 0 + + +class Usages(BaseModel): + """Prompt usage class.""" + + usages: List[Usage] + + +class LMModelChoice(BaseModel): + """Model single completion.""" + + text: str + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[int]] = None + + +class ArrayModelChoice(BaseModel): + """Model single completion.""" + + array: np.ndarray + token_logprobs: 
Optional[List[float]] = None + + class Config: + """Pydantic config class.""" + + arbitrary_types_allowed = True + + +class ModelChoices(BaseModel): + """Model choices.""" + + choices: List[Union[LMModelChoice, ArrayModelChoice]] + + class Response: """Response class.""" def __init__( self, - response: Dict, # TODO: make pydantic model + response: ModelChoices, cached: bool, - request_params: Dict, # TODO: use request pydantic model - generation_key: str = "choices", - logits_key: str = "token_logprobs", - item_key: str = "text", - usage_key: str = "usage", + request: Request, + response_type: str, + request_type: Type[Request], + usages: Optional[Usages] = None, ): """ Initialize response. Args: response: response dict. + usages: usage dict. cached: whether response is cached. - request_params: request parameters. - generation_key: key for generation results. - logits_key: key for logits. - item_key: key for item in the generations. + request: request. + response_type: response type. + request_type: request type. """ - self.generation_key = generation_key - self.logits_key = logits_key - self.item_key = item_key - self.usage_key = usage_key - self.item_dtype = None - if isinstance(response, dict): - self._response = response - else: - raise ValueError(f"Response must be dict. Response is\n{response}.") - if ( - (self.generation_key not in self._response) - or (not isinstance(self._response[self.generation_key], list)) - or (len(self._response[self.generation_key]) <= 0) - ): - raise ValueError( - "Response must be serialized to a dict with a nonempty" - f" list of choices. Response is\n{self._response}." - ) - # Turn off usage if it is not in response - if self.usage_key not in self._response: - self.usage_key = None - else: - if not isinstance(self._response[self.usage_key], list): - raise ValueError( - "Response must be a list with usage dicts, one per choice." - f" Response is\n{self._response}." - ) - - if self.item_key not in self._response[self.generation_key][0]: - raise ValueError( - "Response must be serialized to a dict with a " - f"list of choices with {self.item_key} field" - ) - if ( - self.logits_key in self._response[self.generation_key][0] - and self._response[self.generation_key][0][self.logits_key] - ): - if not isinstance( - self._response[self.generation_key][0][self.logits_key], list - ): - raise ValueError( - f"{self.logits_key} must be a list of items " - "one for each token in the choice." 
- ) - if isinstance( - self._response[self.generation_key][0][self.item_key], np.ndarray - ): - self.item_dtype = str( - self._response[self.generation_key][0][self.item_key].dtype - ) + self._item_dtype = None + self._response_type = response_type + if self._response_type not in {"array", "text"}: + raise ValueError(f"Invalid response type {self._response_type}") + self._request_type = request_type + self._response = response + self._usages = usages or Usages(usages=[]) self._cached = cached - self._request_params = request_params + self._request = request + if self._response.choices: + if response_type == "array": + if not isinstance(self._response.choices[0], ArrayModelChoice): + raise ValueError( + "response_type is array but response is " + f"{self._response.choices[0].__class__}" + ) + self._item_dtype = str( + cast(ArrayModelChoice, self._response.choices[0]).array.dtype + ) + else: + if not isinstance(self._response.choices[0], LMModelChoice): + raise ValueError( + "response_type is text but response is " + f"{self._response.choices[0].__class__}" + ) def is_cached(self) -> bool: """Check if response is cached.""" return self._cached - def get_request(self) -> Dict: + def get_request_obj(self) -> Request: """Get request parameters.""" - return self._request_params + return self._request + + def get_response_obj(self) -> ModelChoices: + """Get response object.""" + return self._response + + def get_usage_obj(self) -> Usages: + """Get usage object.""" + return self._usages def get_json_response(self) -> Dict: """Get response dict without parsing.""" - return self._response + return self._response.dict() def get_response( self, stop_token: str = "", is_batch: bool = False @@ -132,7 +156,8 @@ class Response: lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip() ) extracted_items = [ - choice[self.item_key] for choice in self._response[self.generation_key] + choice.text if isinstance(choice, LMModelChoice) else choice.array + for choice in self._response.choices ] if len(extracted_items) == 0: return None @@ -153,25 +178,15 @@ class Response: if len(responses) == 1: return responses[0] first_response = responses[0] - generation_key = first_response.generation_key - logits_key = first_response.logits_key - item_key = first_response.item_key - # Usage key may be None, so get first not-None value - possible_usage_keys = [r.usage_key for r in responses if r.usage_key] - if possible_usage_keys: - usage_key = possible_usage_keys[0] - else: - usage_key = None - request = first_response._request_params + request_type = first_response._request_type + response_type = first_response._response_type + request = first_response.get_request_obj() # Make sure all responses have the same keys if not all( [ - (r.generation_key == generation_key) - and (r.logits_key == logits_key) - and (r.item_key == item_key) - # Usage key can be empty - and (not r.usage_key or not usage_key or r.usage_key == usage_key) + (r._request_type == request_type) + and (r._response_type == response_type) for r in responses ] ): @@ -181,33 +196,31 @@ class Response: all_prompts = [] all_choices = [] all_usages = [] + all_engines = [] for res in responses: - json_response = res.get_json_response() - res_prompt = res.get_request()["prompt"] + all_engines.extend(res.get_request_obj().engine.split(ENGINE_SEP)) + res_prompt = res.get_request_obj().prompt if isinstance(res_prompt, str): res_prompt = [res_prompt] all_prompts.extend(res_prompt) - all_choices.extend(json_response[generation_key]) - if usage_key and 
usage_key in json_response: - all_usages.extend(json_response[usage_key]) + all_choices.extend(res.get_response_obj().choices) + if res.get_usage_obj().usages: + all_usages.extend(res.get_usage_obj().usages) else: - # Add empty usage - all_usages.extend([{}] * len(res_prompt)) + # Add empty usages if not present + all_usages.extend([Usage()] * len(res_prompt)) new_request = copy.deepcopy(request) - # TODO: add both models back in request. This should be a lot - # easier after I pydantic the response and request more formally - new_request["prompt"] = all_prompts - new_response = {generation_key: all_choices} - if usage_key: - new_response[usage_key] = all_usages + new_request.engine = ENGINE_SEP.join(sorted(set(all_engines))) + new_request.prompt = all_prompts + new_response = ModelChoices(choices=all_choices) + new_usages = Usages(usages=all_usages) response_obj = cls( - new_response, + response=new_response, cached=any(res.is_cached() for res in responses), - request_params=new_request, - generation_key=generation_key, - logits_key=logits_key, - item_key=item_key, - usage_key=usage_key, + request=new_request, + usages=new_usages, + request_type=request_type, + response_type=response_type, ) return response_obj @@ -232,56 +245,74 @@ class Response: serialized response. """ deserialized = json.loads(value) - item_dtype = deserialized["item_dtype"] - if item_dtype: - for choice in deserialized["response"][deserialized["generation_key"]]: - choice[deserialized["item_key"]] = np.array( - choice[deserialized["item_key"]] - ).astype(item_dtype) - return cls( - deserialized["response"], - deserialized["cached"], - deserialized["request_params"], - generation_key=deserialized["generation_key"], - logits_key=deserialized["logits_key"], - item_key=deserialized["item_key"], - ) + return cls.from_dict(deserialized) - def to_dict(self) -> Dict: + def to_dict(self, drop_request: bool = False) -> Dict: """ Get dictionary representation of response. Returns: dictionary representation of response. """ - return { - "generation_key": self.generation_key, - "logits_key": self.logits_key, - "item_key": self.item_key, - "item_dtype": self.item_dtype, - "response": self._response, + to_return = { + "response": self._response.dict(), + "usages": self._usages.dict(), "cached": self._cached, - "request_params": self._request_params, + "request": self._request.dict(), + "response_type": self._response_type, + "request_type": str(self._request_type.__name__), + "item_dtype": self._item_dtype, } + if drop_request: + to_return.pop("request") + return to_return @classmethod - def from_dict(cls, response: Dict) -> "Response": + def from_dict( + cls, response_dict: Dict, request_dict: Optional[Dict] = None + ) -> "Response": """ Create response from dictionary. Args: response: dictionary representation of response. + request_dict: dictionary representation of request which + will override what is in response_dict. Returns: response. """ + if "request" not in response_dict and request_dict is None: + raise ValueError( + "Request dictionary must be provided if " + "request is not in response dictionary." 
+ ) + item_dtype = response_dict["item_dtype"] + response_type = response_dict["response_type"] + if response_dict["request_type"] == "LMRequest": + request_type: Type[Request] = LMRequest + elif response_dict["request_type"] == "LMScoreRequest": + request_type = LMScoreRequest + elif response_dict["request_type"] == "EmbeddingRequest": + request_type = EmbeddingRequest + elif response_dict["request_type"] == "DiffusionRequest": + request_type = DiffusionRequest + choices: List[Union[LMModelChoice, ArrayModelChoice]] = [] + if item_dtype and response_type == "array": + for choice in response_dict["response"]["choices"]: + choice["array"] = np.array(choice["array"]).astype(item_dtype) + choices.append(ArrayModelChoice(**choice)) + else: + for choice in response_dict["response"]["choices"]: + choices.append(LMModelChoice(**choice)) + response = ModelChoices(choices=choices) return cls( - response["response"], - response["cached"], - response["request_params"], - generation_key=response["generation_key"], - logits_key=response["logits_key"], - item_key=response["item_key"], + response=response, + usages=Usages(**response_dict["usages"]), + cached=response_dict["cached"], + request=request_type(**(request_dict or response_dict["request"])), + response_type=response_type, + request_type=request_type, ) def __str__(self) -> str: diff --git a/tests/conftest.py b/tests/conftest.py index 3687751..1a300a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,94 @@ import shutil from pathlib import Path from typing import Generator +import numpy as np import pytest import redis +from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest +from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices + + +@pytest.fixture +def model_choice() -> ModelChoices: + """Get dummy model choice.""" + model_choices = ModelChoices( + choices=[ + LMModelChoice(text="hello", token_logprobs=[0.1, 0.2]), + LMModelChoice(text="bye", token_logprobs=[0.3]), + ] + ) + return model_choices + + +@pytest.fixture +def model_choice_single() -> ModelChoices: + """Get dummy model choice.""" + model_choices = ModelChoices( + choices=[ + LMModelChoice(text="helloo", token_logprobs=[0.1, 0.2]), + ] + ) + return model_choices + + +@pytest.fixture +def model_choice_arr() -> ModelChoices: + """Get dummy model choice.""" + np.random.seed(0) + model_choices = ModelChoices( + choices=[ + ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.1, 0.2]), + ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.3]), + ] + ) + return model_choices + + +@pytest.fixture +def model_choice_arr_int() -> ModelChoices: + """Get dummy model choice.""" + np.random.seed(0) + model_choices = ModelChoices( + choices=[ + ArrayModelChoice( + array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.1, 0.2] + ), + ArrayModelChoice( + array=np.random.randint(20, size=(4, 4)), token_logprobs=[0.3] + ), + ] + ) + return model_choices + + +@pytest.fixture +def request_lm() -> LMRequest: + """Get dummy request.""" + request = LMRequest(prompt=["what", "cat"]) + return request + + +@pytest.fixture +def request_lm_single() -> LMRequest: + """Get dummy request.""" + request = LMRequest(prompt="monkey", engine="dummy") + return request + + +@pytest.fixture +def request_array() -> EmbeddingRequest: + """Get dummy request.""" + request = EmbeddingRequest(prompt="hello") + return request + + +@pytest.fixture +def request_diff() -> DiffusionRequest: + """Get dummy request.""" + request = 
DiffusionRequest(prompt="hello") + return request + @pytest.fixture def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]: diff --git a/tests/test_cache.py b/tests/test_cache.py index 69bef7b..266a60f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -12,6 +12,7 @@ from manifest.caches.postgres import PostgresCache from manifest.caches.redis import RedisCache from manifest.caches.sqlite import SQLiteCache from manifest.request import DiffusionRequest, LMRequest, Request +from manifest.response import ArrayModelChoice, ModelChoices, Response def _get_postgres_cache( @@ -78,7 +79,16 @@ def test_key_get_and_set( @pytest.mark.usefixtures("postgres_cache") @pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"]) def test_get( - sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str + sqlite_cache: str, + redis_cache: str, + postgres_cache: str, + cache_type: str, + model_choice: ModelChoices, + model_choice_single: ModelChoices, + model_choice_arr_int: ModelChoices, + request_lm: LMRequest, + request_lm_single: LMRequest, + request_diff: DiffusionRequest, ) -> None: """Test cache save prompt.""" if cache_type == "sqlite": @@ -88,22 +98,51 @@ def test_get( elif cache_type == "postgres": cache = cast(Cache, _get_postgres_cache()) - test_request = {"test": "hello", "testA": "world"} - test_response = {"choices": [{"text": "hello"}]} + response = Response( + response=model_choice_single, + cached=False, + request=request_lm_single, + usages=None, + request_type=LMRequest, + response_type="text", + ) - response = cache.get(test_request) - assert response is None + cache_response = cache.get(request_lm_single.dict()) + assert cache_response is None + + cache.set(request_lm_single.dict(), response.to_dict(drop_request=True)) + cache_response = cache.get(request_lm_single.dict()) + assert cache_response.get_response() == "helloo" + assert cache_response.is_cached() + assert cache_response.get_request_obj() == request_lm_single + + response = Response( + response=model_choice, + cached=False, + request=request_lm, + usages=None, + request_type=LMRequest, + response_type="text", + ) - cache.set(test_request, test_response) - response = cache.get(test_request) - assert response.get_response() == "hello" - assert response.is_cached() - assert response.get_request() == test_request + cache_response = cache.get(request_lm.dict()) + assert cache_response is None + + cache.set(request_lm.dict(), response.to_dict(drop_request=True)) + cache_response = cache.get(request_lm.dict()) + assert cache_response.get_response() == ["hello", "bye"] + assert cache_response.is_cached() + assert cache_response.get_request_obj() == request_lm # Test array - arr = np.random.rand(4, 4) - test_request = {"test": "hello", "testA": "world of images"} - compute_arr_response = {"choices": [{"array": arr}]} + response = Response( + response=model_choice_arr_int, + cached=False, + request=request_diff, + usages=None, + request_type=DiffusionRequest, + response_type="array", + ) if cache_type == "sqlite": cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest) @@ -112,103 +151,34 @@ def test_get( elif cache_type == "postgres": cache = _get_postgres_cache(request_type=DiffusionRequest) - response = cache.get(test_request) - assert response is None + cache_response = cache.get(request_diff.dict()) + assert cache_response is None - cache.set(test_request, compute_arr_response) - response = cache.get(test_request) - assert np.allclose(response.get_response(), arr) - assert 
response.is_cached() - assert response.get_request() == test_request + cache.set(request_diff.dict(), response.to_dict(drop_request=True)) + cached_response = cache.get(request_diff.dict()) + assert np.allclose( + cached_response.get_response()[0], + cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array, + ) + assert np.allclose( + cached_response.get_response()[1], + cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array, + ) + assert cached_response.is_cached() + assert cached_response.get_request_obj() == request_diff # Test array byte string - arr = np.random.rand(4, 4) - test_request = {"test": "hello", "testA": "world of images 2"} - compute_arr_response = {"choices": [{"array": arr}]} - - if cache_type == "sqlite": - cache = SQLiteCache( - sqlite_cache, - request_type=DiffusionRequest, - cache_args={"array_serializer": "byte_string"}, - ) - elif cache_type == "redis": - cache = RedisCache( - redis_cache, - request_type=DiffusionRequest, - cache_args={"array_serializer": "byte_string"}, - ) - elif cache_type == "postgres": - cache = _get_postgres_cache( - request_type=DiffusionRequest, - cache_args={"array_serializer": "byte_string"}, - ) - - response = cache.get(test_request) - assert response is None - - cache.set(test_request, compute_arr_response) - response = cache.get(test_request) - assert np.allclose(response.get_response(), arr) - assert response.is_cached() - assert response.get_request() == test_request - - -@pytest.mark.usefixtures("sqlite_cache") -@pytest.mark.usefixtures("redis_cache") -@pytest.mark.usefixtures("postgres_cache") -@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"]) -def test_get_batch_prompt( - sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str -) -> None: - """Test cache save prompt.""" - if cache_type == "sqlite": - cache = cast(Cache, SQLiteCache(sqlite_cache)) - elif cache_type == "redis": - cache = cast(Cache, RedisCache(redis_cache)) - elif cache_type == "postgres": - cache = cast(Cache, _get_postgres_cache()) - - test_request = {"test": ["hello", "goodbye"], "testA": "world"} - test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]} - - response = cache.get(test_request) - assert response is None - - cache.set(test_request, test_response) - response = cache.get(test_request) - assert response.get_response() == ["hello", "goodbye"] - assert response.is_cached() - assert response.get_request() == test_request - - # Test arrays - arr = np.random.rand(4, 4) - arr2 = np.random.rand(4, 4) - test_request = {"test": ["hello", "goodbye"], "testA": "world of images"} - compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]} - - if cache_type == "sqlite": - cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest) - elif cache_type == "redis": - cache = RedisCache(redis_cache, request_type=DiffusionRequest) - elif cache_type == "postgres": - cache = _get_postgres_cache(request_type=DiffusionRequest) - - response = cache.get(test_request) - assert response is None - - cache.set(test_request, compute_arr_response) - response = cache.get(test_request) - assert np.allclose(response.get_response()[0], arr) - assert np.allclose(response.get_response()[1], arr2) - assert response.is_cached() - assert response.get_request() == test_request - - # Test arrays byte serializer - arr = np.random.rand(4, 4) - arr2 = np.random.rand(4, 4) - test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"} - compute_arr_response = {"choices": [{"array": arr}, {"array": 
arr2}]} + # Make sure to not hit the cache + new_request_diff = DiffusionRequest(**request_diff.dict()) + new_request_diff.prompt = ["blahhh", "yayayay"] + response = Response( + response=model_choice_arr_int, + cached=False, + request=new_request_diff, + usages=None, + request_type=DiffusionRequest, + response_type="array", + ) if cache_type == "sqlite": cache = SQLiteCache( @@ -228,15 +198,21 @@ def test_get_batch_prompt( cache_args={"array_serializer": "byte_string"}, ) - response = cache.get(test_request) - assert response is None + cached_response = cache.get(new_request_diff.dict()) + assert cached_response is None - cache.set(test_request, compute_arr_response) - response = cache.get(test_request) - assert np.allclose(response.get_response()[0], arr) - assert np.allclose(response.get_response()[1], arr2) - assert response.is_cached() - assert response.get_request() == test_request + cache.set(new_request_diff.dict(), response.to_dict(drop_request=True)) + cached_response = cache.get(new_request_diff.dict()) + assert np.allclose( + cached_response.get_response()[0], + cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array, + ) + assert np.allclose( + cached_response.get_response()[1], + cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array, + ) + assert cached_response.is_cached() + assert cached_response.get_request_obj() == new_request_diff def test_noop_cache() -> None: diff --git a/tests/test_client.py b/tests/test_client.py index 71126e5..5f5810b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -33,10 +33,13 @@ def test_get_request() -> None: "prompt": "hello", "num_results": 3, "engine": "dummy", + "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello"}] * 3, - "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3, + "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3, + } + assert response.get_usage_obj().dict() == { + "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3, } request_params = client.get_request("hello", {"n": 5}) @@ -45,10 +48,13 @@ def test_get_request() -> None: "prompt": "hello", "num_results": 5, "engine": "dummy", + "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello"}] * 5, - "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, + "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, + } + assert response.get_usage_obj().dict() == { + "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, } request_params = client.get_request(["hello"] * 5, {"n": 1}) @@ -57,8 +63,11 @@ def test_get_request() -> None: "prompt": ["hello"] * 5, "num_results": 1, "engine": "dummy", + "request_cls": "LMRequest", } assert response.get_json_response() == { - "choices": [{"text": "hello"}] * 5, - "usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, + "choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, + } + assert response.get_usage_obj().dict() == { + "usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, } diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 52af0e3..9ecd9b6 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -90,8 +90,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: if return_response: assert isinstance(result, Response) result = cast(Response, result) - 
assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -101,6 +101,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -116,8 +117,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -127,6 +128,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
                 "run_id": "34",
             }
@@ -143,8 +145,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token)
     else:
@@ -154,6 +156,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -169,8 +172,8 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(stop_token="ll")
     else:
@@ -180,6 +183,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -212,8 +216,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
     else:
@@ -224,6 +228,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -235,8 +240,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
     else:
@@ -247,6 +252,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -262,6 +268,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
             {
                 "prompt": "New prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": n,
             },
         )
@@ -272,8 +279,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(manifest.stop_token, is_batch=True)
         # Cached because one item is in cache
@@ -287,8 +294,8 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
     if return_response:
         assert isinstance(result, Response)
         result = cast(Response, result)
-        assert len(result.get_json_response()["usage"]) == len(
-            result.get_json_response()["choices"]
+        assert len(result.get_usage_obj().usages) == len(
+            result.get_response_obj().choices
         )
         res = result.get_response(stop_token="ll", is_batch=True)
     else:
@@ -309,9 +316,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert res == ["hello"]
     assert (
@@ -319,6 +324,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": 1,
             },
         )
@@ -330,9 +336,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert res == ["hello", "hello"]
     assert (
@@ -340,6 +344,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": 1,
             },
         )
@@ -350,9 +355,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     assert result.is_cached()
@@ -361,6 +364,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
             {
                 "prompt": "New prompt",
                 "engine": "dummy",
+                "request_cls": "LMRequest",
                 "num_results": 1,
             },
         )
@@ -371,9 +375,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(manifest.stop_token, is_batch=True)
     # Cached because one item is in cache
     assert result.is_cached()
@@ -384,9 +386,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
         Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
     )
-    assert len(result.get_json_response()["usage"]) == len(
-        result.get_json_response()["choices"]
-    )
+    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
     res = result.get_response(stop_token="ll", is_batch=True)
     assert res == ["he", "he"]
@@ -407,25 +407,43 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "This is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
     )
     assert result == {
-        "generation_key": "choices",
-        "logits_key": "token_logprobs",
-        "item_key": "text",
-        "item_dtype": None,
-        "response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
+        "response": {
+            "choices": [
+                {"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None}
+            ]
+        },
+        "usages": {"usages": []},
         "cached": False,
-        "request_params": {
+        "request": {
             "prompt": "This is a prompt",
-            "engine": "dummy",
-            "num_results": 1,
-            "request_type": "score_prompt",
+            "engine": "text-ada-001",
+            "n": 1,
+            "client_timeout": 60,
+            "run_id": None,
+            "batch_size": 8,
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "top_p": 1.0,
+            "top_k": 50,
+            "logprobs": None,
+            "stop_sequences": None,
+            "num_beams": 1,
+            "do_sample": False,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
         },
+        "response_type": "text",
+        "request_type": "LMScoreRequest",
+        "item_dtype": None,
     }

     prompt_list = ["Hello is a prompt", "Hello is another prompt"]
@@ -435,8 +453,8 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "Hello is a prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
@@ -446,30 +464,48 @@ def test_score_run(sqlite_cache: str) -> None:
             {
                 "prompt": "Hello is another prompt",
                 "engine": "dummy",
+                "request_cls": "LMScoreRequest",
                 "num_results": 1,
-                "request_type": "score_prompt",
             },
         )
         is not None
     )
     assert result == {
-        "generation_key": "choices",
-        "logits_key": "token_logprobs",
-        "item_key": "text",
-        "item_dtype": None,
         "response": {
             "choices": [
-                {"text": "Hello is a prompt", "logprob": 0.3},
-                {"text": "Hello is another prompt", "logprob": 0.3},
+                {"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None},
+                {
+                    "text": "Hello is another prompt",
+                    "token_logprobs": [0.3],
+                    "tokens": None,
+                },
             ]
         },
+        "usages": {"usages": []},
         "cached": False,
-        "request_params": {
+        "request": {
             "prompt": ["Hello is a prompt", "Hello is another prompt"],
-            "engine": "dummy",
-            "num_results": 1,
-            "request_type": "score_prompt",
+            "engine": "text-ada-001",
+            "n": 1,
+            "client_timeout": 60,
+            "run_id": None,
+            "batch_size": 8,
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "top_p": 1.0,
+            "top_k": 50,
+            "logprobs": None,
+            "stop_sequences": None,
+            "num_beams": 1,
+            "do_sample": False,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
        },
+        "response_type": "text",
+        "request_type": "LMScoreRequest",
+        "item_dtype": None,
     }
@@ -644,8 +680,8 @@ def test_openai(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
     assert response.get_response() == res
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 15
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 15

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -662,12 +698,9 @@ def test_openai(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 15
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 16
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 15
+    assert response.get_usage_obj().usages[1].total_tokens == 16

     response = cast(
         Response, client.run("Why are there bananas?", return_response=True)
@@ -691,12 +724,9 @@ def test_openai(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 17
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 15
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 17
+    assert response.get_usage_obj().usages[1].total_tokens == 15

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)
@@ -721,8 +751,8 @@ def test_openaichat(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
     assert response.get_response() == res
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 23
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 23

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -749,12 +779,9 @@ def test_openaichat(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 25
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 23
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 25
+    assert response.get_usage_obj().usages[1].total_tokens == 23

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)
@@ -795,8 +822,8 @@ def test_openaiembedding(sqlite_cache: str) -> None:
     assert isinstance(response.get_response(), np.ndarray)
     assert np.allclose(response.get_response(), res)
     assert response.is_cached() is True
-    assert "usage" in response.get_json_response()
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 5
+    assert response.get_usage_obj().usages
+    assert response.get_usage_obj().usages[0].total_tokens == 5

     response = cast(Response, client.run("Why are there apples?", return_response=True))
     assert response.is_cached() is True
@@ -817,12 +844,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
     assert (
         isinstance(response.get_response(), list) and len(response.get_response()) == 2
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 5
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 6
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 5
+    assert response.get_usage_obj().usages[1].total_tokens == 6

     response = cast(
         Response, client.run("Why are there bananas?", return_response=True)
@@ -857,12 +881,9 @@ def test_openaiembedding(sqlite_cache: str) -> None:
         and len(res_list) == 2
         and isinstance(res_list[0], np.ndarray)
     )
-    assert (
-        "usage" in response.get_json_response()
-        and len(response.get_json_response()["usage"]) == 2
-    )
-    assert response.get_json_response()["usage"][0]["total_tokens"] == 7
-    assert response.get_json_response()["usage"][1]["total_tokens"] == 5
+    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
+    assert response.get_usage_obj().usages[0].total_tokens == 7
+    assert response.get_usage_obj().usages[1].total_tokens == 5

     response = cast(
         Response, client.run("Why are there oranges?", return_response=True)
diff --git a/tests/test_response.py b/tests/test_response.py
index ba8283b..3208876 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -1,230 +1,301 @@
 """Response test."""
-from typing import Any, Dict
+from typing import List, cast

 import numpy as np
 import pytest

 from manifest import Response
-from manifest.request import LMRequest
+from manifest.request import EmbeddingRequest, LMRequest
+from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages


-def test_init() -> None:
+def test_init(
+    model_choice: ModelChoices,
+    model_choice_arr: ModelChoices,
+    model_choice_arr_int: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
     """Test response initialization."""
-    with pytest.raises(ValueError) as exc_info:
-        response = Response(4, False, {})  # type: ignore
-    assert str(exc_info.value) == "Response must be dict. Response is\n4."
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"test": "hello"}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict with a nonempty list of choices. "
-        "Response is\n{'test': 'hello'}."
-    )
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"choices": [{"blah": "hello"}]}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict "
-        "with a list of choices with text field"
-    )
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"choices": []}, False, {})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict with a nonempty list of choices. "
-        "Response is\n{'choices': []}."
-    )
-
-    response = Response({"choices": [{"text": "hello"}]}, False, {})
-    assert response._response == {"choices": [{"text": "hello"}]}
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response._response == model_choice
     assert response._cached is False
-    assert response._request_params == {}
-    assert response.item_dtype is None
-
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response._response == {"choices": [{"text": "hello"}]}
-    assert response._cached is True
-    assert response._request_params == {"request": "yoyo"}
-    assert response.item_dtype is None
+    assert response._request == request_lm
+    assert response._usages == Usages(usages=[])
+    assert response._request_type == LMRequest
+    assert response._response_type == "text"
+    assert response._item_dtype is None

     response = Response(
-        {"generations": [{"txt": "hello"}], "logits": []},
-        False,
-        {},
-        generation_key="generations",
-        logits_key="logits",
-        item_key="txt",
-    )
-    assert response._response == {"generations": [{"txt": "hello"}], "logits": []}
+        response=model_choice_arr_int,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
+    )
     assert response._cached is False
-    assert response._request_params == {}
-    assert response.item_dtype is None
+    assert response._request == request_array
+    assert sum([usg.total_tokens for usg in response._usages.usages]) == 10
+    assert response._request_type == EmbeddingRequest
+    assert response._response_type == "array"
+    assert response._item_dtype == "int64"

-    int_arr = np.random.randint(20, size=(4, 4))
-    response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice,
+            cached=False,
+            request=request_lm,
+            usages=None,
+            request_type=LMRequest,
+            response_type="blah",
+        )
+    assert "blah" in str(excinfo.value)
+
+    # Can't convert array with text
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice,
+            cached=False,
+            request=request_lm,
+            usages=None,
+            request_type=LMRequest,
+            response_type="array",
+        )
+    assert str(excinfo.value) == (
+        "response_type is array but response is "
+        ""
+    )
+
+    # Can't convert text with array
+    with pytest.raises(ValueError) as excinfo:
+        Response(
+            response=model_choice_arr,
+            cached=False,
+            request=request_array,
+            usages=None,
+            request_type=LMRequest,
+            response_type="text",
+        )
+    assert str(excinfo.value) == (
+        "response_type is text but response is "
+        ""
     )
-    assert response._response == {"choices": [{"array": int_arr}]}
-    assert response._cached is True
-    assert response._request_params == {"request": "yoyo"}
-    assert response.item_dtype == "int64"


-def test_getters() -> None:
+def test_getters(model_choice: ModelChoices, request_lm: LMRequest) -> None:
     """Test response cached."""
-    response = Response({"choices": [{"text": "hello"}]}, False, {})
-    assert response.get_json_response() == {"choices": [{"text": "hello"}]}
+    response = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response.get_response_obj() == model_choice
     assert response.is_cached() is False
-    assert response.get_request() == {}
+    assert response.get_request_obj() == request_lm
+    assert response.get_usage_obj() == Usages(usages=[])
+    assert response.get_json_response() == model_choice.dict()
+    assert response.get_response() == ["hello", "bye"]

-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response.get_json_response() == {"choices": [{"text": "hello"}]}
-    assert response.is_cached() is True
-    assert response.get_request() == {"request": "yoyo"}

-    int_arr = np.random.randint(20, size=(4, 4))
+def test_serialize(
+    model_choice: ModelChoices,
+    model_choice_arr: ModelChoices,
+    model_choice_arr_int: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
+    """Test response serialization."""
     response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
     )
-    assert response.get_json_response() == {"choices": [{"array": int_arr}]}
-    assert response.is_cached() is True
-    assert response.get_request() == {"request": "yoyo"}
+    deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj() == request_lm
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]

+    deserialized_response = Response.from_dict(response.to_dict())
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj() == request_lm
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]

-def test_serialize() -> None:
-    """Test response serialization."""
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    deserialized_response = Response.deserialize(response.serialize())
-    assert deserialized_response._response == {"choices": [{"text": "hello"}]}
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}
+    deserialized_response = Response.from_dict(
+        response.to_dict(drop_request=True), request_dict={"prompt": "blahhhh"}
+    )
+    assert deserialized_response.get_response_obj() == model_choice
+    assert deserialized_response.is_cached() is False
+    assert deserialized_response.get_request_obj().prompt == "blahhhh"
+    assert deserialized_response.get_usage_obj() == Usages(usages=[])
+    assert deserialized_response.get_json_response() == model_choice.dict()
+    assert deserialized_response.get_response() == ["hello", "bye"]

-    int_arr = np.random.randint(20, size=(4, 4))
+    # Int type
     response = Response(
-        {"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
+        response=model_choice_arr_int,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
     deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response._item_dtype == "int64"
+    assert (
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array.dtype
+        == np.int64
+    )
     assert np.array_equal(
-        deserialized_response._response["choices"][0]["array"], int_arr
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array,
+        cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
     )
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}

-    float_arr = np.random.randn(4, 4)
+    # Float type
     response = Response(
-        {"choices": [{"array": float_arr}]}, True, {"request": "yoyo"}, item_key="array"
+        response=model_choice_arr,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
     deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response._item_dtype == "float64"
+    assert (
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array.dtype
+        == np.float64
+    )
     assert np.array_equal(
-        deserialized_response._response["choices"][0]["array"], float_arr
+        cast(
+            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
+        ).array,
+        cast(ArrayModelChoice, model_choice_arr.choices[0]).array,
     )
-    assert deserialized_response.is_cached() is True
-    assert deserialized_response._request_params == {"request": "yoyo"}


-def test_get_results() -> None:
+def test_get_results(
+    model_choice: ModelChoices,
+    model_choice_single: ModelChoices,
+    model_choice_arr: ModelChoices,
+    request_lm: LMRequest,
+    request_array: EmbeddingRequest,
+) -> None:
     """Test response get results."""
-    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
-    assert response.get_response() == "hello"
+    response = Response(
+        response=model_choice_single,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )
+    assert response.get_response() == "helloo"
     assert response.get_response(stop_token="ll") == "he"
     assert response.get_response(stop_token="ll", is_batch=True) == ["he"]

     response = Response(
-        {"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
-        True,
-        {"request": "yoyo"},
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
     )
-    assert response.get_response() == ["hello", "my", "name"]
-    assert response.get_response(stop_token="m") == ["hello", "", "na"]
-    assert response.get_response(stop_token="m", is_batch=True) == ["hello", "", "na"]
+    assert response.get_response() == ["hello", "bye"]
+    assert response.get_response(stop_token="b") == ["hello", ""]
+    assert response.get_response(stop_token="y", is_batch=True) == ["hello", "b"]

-    float_arr = np.random.randn(4, 4)
+    float_arr1 = cast(ArrayModelChoice, model_choice_arr.choices[0]).array
+    float_arr2 = cast(ArrayModelChoice, model_choice_arr.choices[1]).array
     response = Response(
-        {"choices": [{"array": float_arr}, {"array": float_arr}]},
-        True,
-        {"request": "yoyo"},
-        item_key="array",
+        response=model_choice_arr,
+        cached=False,
+        request=request_array,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=EmbeddingRequest,
+        response_type="array",
     )
-    assert response.get_response() == [float_arr, float_arr]
-    assert response.get_response(stop_token="m") == [float_arr, float_arr]
+    assert np.array_equal(response.get_response()[0], float_arr1)
+    assert np.array_equal(response.get_response()[1], float_arr2)
+    assert np.array_equal(response.get_response(stop_token="t")[0], float_arr1)
+    assert np.array_equal(response.get_response(stop_token="t")[1], float_arr2)


-def test_union_all() -> None:
+def test_union_all(
+    model_choice: ModelChoices,
+    model_choice_single: ModelChoices,
+    request_lm: LMRequest,
+    request_lm_single: LMRequest,
+) -> None:
     """Test union all."""
-    request_paramsa = LMRequest(prompt=["apple", "orange", "pear"]).to_dict()
-    request_paramsa["model"] = "modelA"
-    response_paramsa = {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-        ]
-    }
-    responsea = Response(response_paramsa, False, request_paramsa)
+    response1 = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )

-    request_paramsb = LMRequest(prompt=["banana", "pineapple", "mango"]).to_dict()
-    request_paramsb["model"] = "modelB"
-    response_paramsb = {
-        "choices": [
-            {"text": "bye", "token_logprobs": [2]},
-            {"text": "bye 2", "token_logprobs": [2]},
-            {"text": "bye 3", "token_logprobs": [2]},
-        ]
-    }
-    responseb = Response(response_paramsb, False, request_paramsb)
+    response2 = Response(
+        response=model_choice_single,
+        cached=False,
+        request=request_lm_single,
+        usages=None,
+        request_type=LMRequest,
+        response_type="text",
+    )

-    final_response = Response.union_all([responsea, responseb])
+    final_response = Response.union_all([response1, response2])
     assert final_response.get_json_response() == {
         "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-            {"text": "bye", "token_logprobs": [2]},
-            {"text": "bye 2", "token_logprobs": [2]},
-            {"text": "bye 3", "token_logprobs": [2]},
+            {"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None},
+            {"text": "bye", "token_logprobs": [0.3], "tokens": None},
+            {"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None},
         ]
     }
-    final_request = LMRequest(
-        prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
-    ).to_dict()
-    final_request["model"] = "modelA"
-    assert final_response.get_request() == final_request
-    assert not final_response.is_cached()
+    assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
+    merged_prompts: List[str] = request_lm.prompt + [request_lm_single.prompt]  # type: ignore # noqa: E501
+    assert final_response.get_request_obj().prompt == merged_prompts
+    assert final_response.get_request_obj().engine == "dummy::text-ada-001"

     # Modify A to have usage and cached
-    response_paramsa_2: Dict[str, Any] = {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-        ],
-        "usage": [
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-            {"completion_tokens": 10},
-        ],
-    }
-    responsea = Response(response_paramsa_2, True, request_paramsa)
+    response1 = Response(
+        response=model_choice,
+        cached=False,
+        request=request_lm,
+        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
+        request_type=LMRequest,
+        response_type="text",
+    )

-    final_response = Response.union_all([responsea, responseb])
-    assert final_response.get_json_response() == {
-        "choices": [
-            {"text": "hello", "token_logprobs": [1]},
-            {"text": "hello 2", "token_logprobs": [1]},
-            {"text": "hello 3", "token_logprobs": [1]},
-            {"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]}, - {"text": "bye 3", "token_logprobs": [2]}, - ], - "usage": [ - {"completion_tokens": 10}, - {"completion_tokens": 10}, - {"completion_tokens": 10}, - {}, - {}, - {}, - ], - } - final_request = LMRequest( - prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"] - ).to_dict() - final_request["model"] = "modelA" - assert final_response.get_request() == final_request - assert final_response.is_cached() + final_response = Response.union_all([response1, response2]) + assert final_response.get_usage_obj() == Usages( + usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()] + ) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index bbc269f..0fbe3b9 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -10,23 +10,23 @@ def test_response_to_key_array() -> None: """Test array serializer initialization.""" serializer = ArraySerializer() arr = np.random.rand(4, 4) - res = {"choices": [{"array": arr}]} + res = {"response": {"choices": [{"array": arr}]}} key = serializer.response_to_key(res) key_dct = json.loads(key) - assert isinstance(key_dct["choices"][0]["array"], str) + assert isinstance(key_dct["response"]["choices"][0]["array"], str) res2 = serializer.key_to_response(key) - assert np.allclose(arr, res2["choices"][0]["array"]) + assert np.allclose(arr, res2["response"]["choices"][0]["array"]) def test_response_to_key_numpybytes() -> None: """Test array serializer initialization.""" serializer = NumpyByteSerializer() arr = np.random.rand(4, 4) - res = {"choices": [{"array": arr}]} + res = {"response": {"choices": [{"array": arr}]}} key = serializer.response_to_key(res) key_dct = json.loads(key) - assert isinstance(key_dct["choices"][0]["array"], str) + assert isinstance(key_dct["response"]["choices"][0]["array"], str) res2 = serializer.key_to_response(key) - assert np.allclose(arr, res2["choices"][0]["array"]) + assert np.allclose(arr, res2["response"]["choices"][0]["array"]) diff --git a/web_app/main.py b/web_app/main.py index 87dbe6c..22368e0 100644 --- a/web_app/main.py +++ b/web_app/main.py @@ -49,7 +49,7 @@ def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict: return { "response": response.get_response(), "cached": response.is_cached(), - "request_params": response.get_request(), + "request_params": response.get_request_obj(), }