diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 08c13db..0bd2b5a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,19 @@ 0.1.1 - Unreleased --------------------- +Added +^^^^^ +* Async support in arun_batch + +Fixed +^^^^^ +* Batched runs now cache individual items +* Score prompt no longer truncates the outside token + +Removed +^^^^^^^ +* Deprecated ChatGPT in favor of openaichat, which uses the OpenAI chat completions API +* Deprecated Sessions 0.1.0 - 2022-01-31 --------------------- diff --git a/README.md b/README.md index 5e39018..0cd8baf 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,6 @@ Install with diffusion support: pip install manifest-ml[diffusers] ``` -Install with ChatGPT support: -```bash -pip install manifest-ml[chatgpt] -``` -This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`. - Install with HuggingFace local model support: ```bash pip install manifest-ml[api] ``` @@ -106,24 +100,6 @@ manifest = Manifest( ``` As a hint, if you want to get Redis running, see the `docker run` command below under development. -## Sessions -Each Manifest run supports a session that, in addition to a global cache, connects to a local SQLite DB to store user query history. -```python -manifest = Manifest( - client_name = "openai", - cache_name = "sqlite", - cache_connection = "mycache.sqlite", - session_id = "grass_color", -) -``` -will start a Manifest session with the session name `grass_color`. This can be helpful for a user to logically keep track of sessions, see interaction history, and resume sessions if desired. If the session id provided is `_default`, we generate a random id for the user. - -After a few queries, the user can explore their history -```python -manifest.get_last_queries(4) -``` -will retrieve the last 4 model queries and responses. - ## Running Queries Once you have a session open, you can write and develop prompts.
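Since the ChatGPT install section is removed, a minimal sketch of the replacement `openaichat` client follows; it assumes `OPENAI_API_KEY` is set in the environment, and the prompt text is purely illustrative:

```python
from manifest import Manifest

# The new openaichat client hits the OpenAI chat completions endpoint
# (default engine "gpt-3.5-turbo"). The key is read from OPENAI_API_KEY,
# or can be passed explicitly via client_connection.
manifest = Manifest(client_name="openaichat")
print(manifest.run("Describe the best sandwich in the world in one short sentence.", max_tokens=30))
```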
diff --git a/examples/manifest_async.py b/examples/manifest_async.py new file mode 100644 index 0000000..e252c73 --- /dev/null +++ b/examples/manifest_async.py @@ -0,0 +1,27 @@ +import asyncio +import time + +from manifest import Manifest + + +def main(): + + manifest = Manifest( + client_name="openaichat", + ) + + print("Running in serial") + prompts = [f"Tell me something interesting about {i}" for i in range(50)] + st = time.time() + for pmt in prompts: + _ = manifest.run(pmt) + print(f"For loop: {time.time() - st :.2f}") + + print("Running with async") + st = time.time() + _ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30)) + print(f"Async loop: {time.time() - st :.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/manifest_chatgpt.ipynb b/examples/manifest_chatgpt.ipynb deleted file mode 100644 index 639aef3..0000000 --- a/examples/manifest_chatgpt.ipynb +++ /dev/null @@ -1,63 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from manifest import Manifest\n", - "import os\n", - "\n", - "# ChatGPT tries hard not to give people programmatic access.\n", - "# As a warning, this will open a browser window.\n", - "# You need to install xvfb and chromium for linux headless mode to work\n", - "# See https://github.com/terry3041/pyChatGPT\n", - "\n", - "# The responses are not fast\n", - "manifest = Manifest(\n", - " client_name=\"chatgpt\",\n", - " client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n", - ")\n", - "print(manifest.run(\"Describe in a single, short sentence what is the best sandwhich in the world. 
Be short and concise.\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mlcore", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "1ea9cc00d433352044b557b1784ac6e58df03de4b7bb312554014351989eb135" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/manifest/__init__.py b/manifest/__init__.py index cec30e2..00484be 100644 --- a/manifest/__init__.py +++ b/manifest/__init__.py @@ -2,6 +2,5 @@ from manifest.manifest import Manifest from manifest.request import Request from manifest.response import Response -from manifest.session import Session -__all__ = ["Manifest", "Response", "Session"] +__all__ = ["Manifest", "Response", "Request"] diff --git a/manifest/api/app.py b/manifest/api/app.py index c63653c..1856844 100644 --- a/manifest/api/app.py +++ b/manifest/api/app.py @@ -174,13 +174,13 @@ def completions() -> Response: if model_type == "diffuser": # Assign None logprob as it's not supported in diffusers results = [ - {"array": r[0], "logprob": None, "token_logprobs": None} + {"array": r[0], "logprob": None, "tokens": None, "token_logprobs": None} for r in result_gens ] res_type = "image_generation" else: results = [ - {"text": r[0], "logprob": r[1], "token_logprobs": r[2]} + {"text": r[0], "logprob": r[1], "tokens": r[2], "token_logprobs": r[3]} for r in result_gens ] res_type = "text_completion" @@ -241,7 +241,8 @@ def score_sequence() -> Response: { "text": prompt if isinstance(prompt, str) else prompt[i], "logprob": r[0], - "token_logprobs": r[1], + "tokens": r[1], + "token_logprobs": r[2], } for i, r in enumerate(score_list) ] diff --git a/manifest/api/models/diffuser.py b/manifest/api/models/diffuser.py index 26d3a3d..ffbc403 100644 --- a/manifest/api/models/diffuser.py +++ b/manifest/api/models/diffuser.py @@ -74,7 +74,7 @@ class DiffuserModel(Model): @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[float]]]: + ) -> List[Tuple[Any, float, List[int], List[float]]]: """ Generate the prompt from model. @@ -91,12 +91,12 @@ class DiffuserModel(Model): prompt = [prompt] result = self.pipeline(prompt, output_type="np.array", **kwargs) # Return None for logprobs and token logprobs - return [(im, None, None) for im in result["images"]] + return [(im, None, None, None) for im in result["images"]] @torch.no_grad() def score_sequence( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[float, List[float]]]: + ) -> List[Tuple[float, List[int], List[float]]]: """ Score a sequence of choices. 
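The signature changes above (and in the HuggingFace model below) add token ids to every generation and scoring result. A small sketch of the new 4-tuple shape, with made-up values rather than real model output:

```python
from typing import Any, List, Tuple

# New generate() return shape: (item, total logprob, token ids, per-token logprobs).
# Diffusers fill the last three slots with None; text models return aligned lists.
results: List[Tuple[Any, float, List[int], List[float]]] = [
    ("hello world", -2.5, [31373, 995], [-1.2, -1.3]),
]
for text, logprob, tokens, token_logprobs in results:
    # Token ids and per-token logprobs line up one-to-one.
    assert len(tokens) == len(token_logprobs)
    print(text, logprob)
```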
diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py index b504921..038b7c3 100644 --- a/manifest/api/models/huggingface.py +++ b/manifest/api/models/huggingface.py @@ -179,6 +179,7 @@ class GenerationPipeline: "logprobs": logits[ range(num_generated_tokens), i, output_seq[-num_generated_tokens:] ].tolist(), + "tokens": output_seq[-num_generated_tokens:].tolist(), } for i, output_seq in enumerate(output_dict.sequences) ] @@ -547,7 +548,7 @@ class TextGenerationModel(HuggingFaceModel): @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[float]]]: + ) -> List[Tuple[Any, float, List[int], List[float]]]: """ Generate the prompt from model. @@ -576,6 +577,7 @@ class TextGenerationModel(HuggingFaceModel): ( cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])), + cast(List[int], r["tokens"]), cast(List[float], r["logprobs"]), ) for r in result @@ -585,7 +587,7 @@ class TextGenerationModel(HuggingFaceModel): @torch.no_grad() def score_sequence( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[float, List[float]]]: + ) -> List[Tuple[float, List[int], List[float]]]: """ Score a sequence of choices. @@ -610,21 +612,20 @@ class TextGenerationModel(HuggingFaceModel): **encoded_prompt, ).logits # For causal decoders, shift logts and labels - labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)[ - ..., 1:, : - ] - masked_log_probs = ( - labels_attention_mask.float() - * torch.log_softmax(logits.float(), dim=-1)[..., :-1, :] + labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1) + masked_log_probs = labels_attention_mask.float() * torch.log_softmax( + logits.float(), dim=-1 ) seq_token_log_probs = torch.gather( - masked_log_probs, -1, encoded_prompt["labels"][..., 1:].unsqueeze(-1) + masked_log_probs, -1, encoded_prompt["labels"].unsqueeze(-1) ) seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1) seq_log_prob = seq_token_log_probs.sum(dim=-1) return [ - (seq, seq_token) - for seq, seq_token in zip( - seq_log_prob.tolist(), seq_token_log_probs.tolist() + (seq, tokens, seq_token) + for seq, tokens, seq_token in zip( + seq_log_prob.tolist(), + encoded_prompt["input_ids"].tolist(), + seq_token_log_probs.tolist(), ) ] diff --git a/manifest/api/models/model.py b/manifest/api/models/model.py index 952807d..84d91ab 100644 --- a/manifest/api/models/model.py +++ b/manifest/api/models/model.py @@ -48,7 +48,7 @@ class Model(ABC): def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[float]]]: + ) -> List[Tuple[Any, float, List[int], List[float]]]: """ Generate the prompt from model. @@ -59,7 +59,7 @@ class Model(ABC): Returns: list of generated text (list of length 1 for 1 generation). - Each item is the response, answer logprob, + Each item is the response, answer logprob, list of tokens, and list of logprobs for each token. """ raise NotImplementedError() @@ -78,7 +78,7 @@ class Model(ABC): def score_sequence( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[float, List[float]]]: + ) -> List[Tuple[float, List[int], List[float]]]: """ Score a sequence of choices. @@ -89,6 +89,6 @@ class Model(ABC): Additional keyword arguments passed along to the :obj:`__call__` method. Returns: - Tuple of scores for each choice and logprobs for the tokens of each choice. + Tuple of total score, tokens, and probs per token. 
""" raise NotImplementedError() diff --git a/manifest/api/response.py b/manifest/api/response.py index 3ce8051..a0860e9 100644 --- a/manifest/api/response.py +++ b/manifest/api/response.py @@ -36,6 +36,7 @@ class ModelResponse: { key: result[key], "logprob": result["logprob"], + "tokens": result["tokens"], "token_logprobs": result["token_logprobs"], } if key == "text" diff --git a/manifest/caches/cache.py b/manifest/caches/cache.py index 2009f73..2a16557 100644 --- a/manifest/caches/cache.py +++ b/manifest/caches/cache.py @@ -1,22 +1,9 @@ """Cache for queries and responses.""" from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Union +from typing import Any, Dict, Union from manifest.caches.serializers import ArraySerializer, Serializer -from manifest.response import Response - -RESPONSE_CONSTRUCTORS = { - "diffuser": { - "generation_key": "choices", - "logits_key": "token_logprobs", - "item_key": "array", - }, - "tomadiffuser": { - "generation_key": "choices", - "logits_key": "token_logprobs", - "item_key": "array", - }, -} +from manifest.response import RESPONSE_CONSTRUCTORS, Response CACHE_CONSTRUCTOR = { "diffuser": ArraySerializer, @@ -101,20 +88,35 @@ class Cache(ABC): """Commit any results.""" raise NotImplementedError() - def get( - self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict] - ) -> Response: - """Get the result of request (by calling compute as needed).""" + def get(self, request: Dict) -> Union[Response, None]: + """Get the result of request (by calling compute as needed). + + Args: + request: request to get. + response: response to get. + + Returns: + Response object or None if not in cache. + """ key = self.serializer.request_to_key(request) cached_response = self.get_key(key) - if cached_response and not overwrite_cache: + if cached_response: cached = True response = self.serializer.key_to_response(cached_response) - else: - # Type Response - response = compute() - self.set_key(key, self.serializer.response_to_key(response)) - cached = False - return Response( - response, cached, request, **RESPONSE_CONSTRUCTORS.get(self.client_name, {}) - ) + return Response( + response, + cached, + request, + **RESPONSE_CONSTRUCTORS.get(self.client_name, {}) + ) + return None + + def set(self, request: Dict, response: Dict) -> None: + """Set the value for the key. + + Args: + request: request to set. + response: response to set. + """ + key = self.serializer.request_to_key(request) + self.set_key(key, self.serializer.response_to_key(response)) diff --git a/manifest/caches/postgres.py b/manifest/caches/postgres.py index 46bd211..c7932b1 100644 --- a/manifest/caches/postgres.py +++ b/manifest/caches/postgres.py @@ -100,7 +100,7 @@ class PostgresCache(Cache): table: table to get key in. """ request = ( - self.session.query(Request) + self.session.query(Request) # type: ignore .filter_by(key=self._hash_key(key, table)) .first() ) @@ -119,7 +119,7 @@ class PostgresCache(Cache): table: table to set key in. 
""" key = self._hash_key(key, table) - request = self.session.query(Request).filter_by(key=key).first() + request = self.session.query(Request).filter_by(key=key).first() # type: ignore if request: request.response = value # type: ignore else: diff --git a/manifest/clients/ai21.py b/manifest/clients/ai21.py index e24738a..24da44d 100644 --- a/manifest/clients/ai21.py +++ b/manifest/clients/ai21.py @@ -27,9 +27,9 @@ class AI21Client(Client): "n": ("numResults", 1), "top_p": ("topP", 1.0), "stop_sequences": ("stopSequences", []), - "client_timeout": ("client_timeout", 60), # seconds } REQUEST_CLS = LMRequest + NAME = "ai21" def connect( self, diff --git a/manifest/clients/chatgpt.py b/manifest/clients/chatgpt.py deleted file mode 100644 index 87f10c5..0000000 --- a/manifest/clients/chatgpt.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Client class.""" -import logging -import os -from typing import Any, Callable, Dict, Optional, Tuple - -from pyChatGPT import ChatGPT - -from manifest.clients.client import Client -from manifest.request import LMRequest, Request - -logger = logging.getLogger(__name__) - - -class ChatGPTClient(Client): - """ChatGPT Client class.""" - - # No params for ChatGPT - PARAMS: Dict[str, Tuple[str, Any]] = {} - REQUEST_CLS = LMRequest - - def connect( - self, connection_str: Optional[str], client_args: Dict[str, Any] - ) -> None: - """ - Connect to ChatGPT. - - We use https://github.com/terry3041/pyChatGPT. - - Arsg: - connection_str: connection string. - client_args: client arguments. - """ - self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str) - if self.session_key is None: - raise ValueError( - "ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment " - "variable or pass through `client_connection`. " - "For details, see https://github.com/terry3041/pyChatGPT " - "and go through instructions for getting a session key." - ) - self.host = None - for key in self.PARAMS: - setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) - self._chat_session = ChatGPT(self.session_key, verbose=False) - - def close(self) -> None: - """Close the client.""" - self._chat_session = None - - def clear_conversations(self) -> None: - """Clear conversations. - - Only works for ChatGPT. - """ - self._chat_session.clear_conversations() - - def get_generation_url(self) -> str: - """Get generation URL.""" - return "" - - def get_generation_header(self) -> Dict[str, str]: - """ - Get generation header. - - Returns: - header. - """ - return {} - - def supports_batch_inference(self) -> bool: - """Return whether the client supports batch inference.""" - return False - - def get_model_params(self) -> Dict: - """ - Get model params. - - By getting model params from the server, we can add to request - and make sure cache keys are unique to model. - - Returns: - model params. - """ - return {"model_name": "chatgpt", "engine": "chatgpt"} - - def format_response(self, response: Dict) -> Dict[str, Any]: - """ - Format response to dict. - - Args: - response: response - - Return: - response as dict - """ - return { - "model": "chatgpt", - "choices": [ - { - "text": response["message"], - } - ], - } - - def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]: - """ - Get request string function. - - Args: - request: request. - - Returns: - request function that takes no input. - request parameters as dict. 
- """ - if isinstance(request.prompt, list): - raise ValueError("ChatGPT does not support batch inference.") - - prompt = str(request.prompt) - request_params = request.to_dict(self.PARAMS) - - def _run_completion() -> Dict: - try: - res = self._chat_session.send_message(prompt) - except Exception as e: - logger.error(f"ChatGPT error {e}.") - raise e - return self.format_response(res) - - return _run_completion, request_params diff --git a/manifest/clients/client.py b/manifest/clients/client.py index 94a8854..ada9a79 100644 --- a/manifest/clients/client.py +++ b/manifest/clients/client.py @@ -1,11 +1,16 @@ """Client class.""" +import asyncio +import copy import logging +import math from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union +import aiohttp import requests -from manifest.request import Request +from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request +from manifest.response import RESPONSE_CONSTRUCTORS, Response logger = logging.getLogger(__name__) @@ -16,6 +21,7 @@ class Client(ABC): # Must be overridden by child class PARAMS: Dict[str, Tuple[str, Any]] = {} REQUEST_CLS = Request + NAME: str = None def __init__( self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {} @@ -93,7 +99,7 @@ class Client(ABC): """ return list(self.PARAMS.keys()) - def get_request_params( + def get_request( self, prompt: Union[str, List[str]], request_args: Dict[str, Any] ) -> Request: """ @@ -109,7 +115,32 @@ class Client(ABC): params = {"prompt": prompt} for key in self.PARAMS: params[key] = request_args.pop(key, getattr(self, key)) - return self.REQUEST_CLS(**params) + for key in DEFAULT_REQUEST_KEYS: + if key not in params and key in request_args: + params[key] = request_args.pop(key) + return self.REQUEST_CLS(**params) # type: ignore + + def get_request_params(self, request: Request) -> Dict[str, Any]: + """Get request params. + + Add default keys that we need for requests such as batch_size. + We drop these before sending to the model. + """ + params_to_add = DEFAULT_REQUEST_KEYS.copy() + params_to_add.update(self.PARAMS) + request_params = request.to_dict(params_to_add) + return request_params + + def get_cache_key(self, request: Request) -> Dict[str, Any]: + """Get cache key for request. + + Skip keys that are not cache keys such as batch_size. + """ + request_params = self.get_request_params(request) + for key in NOT_CACHE_KEYS: + request_params.pop(key, None) + request_params.update(self.get_model_params()) + return request_params def format_response(self, response: Dict) -> Dict[str, Any]: """ @@ -125,7 +156,90 @@ class Client(ABC): raise ValueError(f"Invalid response: {response}") return response - def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]: + def split_requests( + self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt" + ) -> List[Dict[str, Any]]: + """Split request into batch_sized request. + + Args: + request_params: request params. + batch_size: batch size for requests. + key: key to batch over + + Returns: + list of request params. 
+ """ + data = copy.deepcopy(request_params[key]) + data_size = len(request_params[key]) + request_params_list = [] + for i in range(0, data_size, batch_size): + params = copy.deepcopy(request_params) + params[key] = data[i] if batch_size == 1 else data[i : i + batch_size] + request_params_list.append(params) + return request_params_list + + def _run_completion( + self, request_params: Dict[str, Any], retry_timeout: int + ) -> Dict: + """Execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + + Returns: + response as dict. + """ + post_str = self.get_generation_url() + try: + res = requests.post( + post_str, + headers=self.get_generation_header(), + json=request_params, + timeout=retry_timeout, + ) + res.raise_for_status() + except requests.Timeout as e: + logger.error( + f"{self.__class__.__name__} request timed out." + " Increase client_timeout." + ) + raise e + except requests.exceptions.HTTPError: + logger.error(res.json()) + raise requests.exceptions.HTTPError(res.json()) + return self.format_response(res.json()) + + async def _arun_completion( + self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int + ) -> Dict: + """Async execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + batch_size: batch size for requests. + + Returns: + response as dict. + """ + post_str = self.get_generation_url() + try: + async with aiohttp.ClientSession(timeout=retry_timeout) as session: + async with session.post( + post_str, + headers=self.get_generation_header(), + json=request_params, + timeout=retry_timeout, + ) as res: + res.raise_for_status() + res_json = await res.json(content_type=None) + return self.format_response(res_json) + except aiohttp.ClientError as e: + logger.error(f"{self.__class__.__name__} request error {e}") + raise e + + def run_request(self, request: Request) -> Response: """ Get request string function. @@ -133,44 +247,80 @@ class Client(ABC): request: request. Returns: - request function that takes no input. - request parameters as dict. + response. """ if isinstance(request.prompt, list) and not self.supports_batch_inference(): raise ValueError( f"{self.__class__.__name__} does not support batch inference." ) - request_params = request.to_dict(self.PARAMS) + request_params = self.get_request_params(request) + # Take the default keys we need and drop the rest as they + # are not part of the model request. retry_timeout = request_params.pop("client_timeout") + for key in DEFAULT_REQUEST_KEYS: + request_params.pop(key, None) + response_dict = self._run_completion(request_params, retry_timeout) + return Response( + response_dict, + cached=False, + request_params=request_params, + **RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore + ) - def _run_completion() -> Dict: - post_str = self.get_generation_url() - try: - res = requests.post( - post_str, - headers=self.get_generation_header(), - json=request_params, - timeout=retry_timeout, - ) - res.raise_for_status() - except requests.Timeout as e: - logger.error( - f"{self.__class__.__name__} request timed out." - " Increase client_timeout." - ) - raise e - except requests.exceptions.HTTPError: - logger.error(res.json()) - raise requests.exceptions.HTTPError(res.json()) - return self.format_response(res.json()) - - return _run_completion, request_params + async def arun_batch_request(self, request: Request) -> Response: + """ + Get async request string function. + + Args: + request: request. 
+ + Returns: + response. + """ + required_batch_size = None + if not self.supports_batch_inference(): + required_batch_size = 1 + if not isinstance(request.prompt, list): + raise AssertionError( + "request.prompt must be a list for async batch inference." + ) + + request_params = self.get_request_params(request) + # Take the default keys we need and drop the rest as they + # are not part of the model request. + retry_timeout = request_params.pop("client_timeout") + batch_size = request_params.pop("batch_size") + batch_size = required_batch_size or batch_size + for key in DEFAULT_REQUEST_KEYS: + request_params.pop(key, None) + + num_batches = len(request.prompt) // batch_size + if len(request.prompt) % batch_size != 0: + batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1))) + + request_batches = self.split_requests(request_params, batch_size) + all_tasks = [ + asyncio.create_task(self._arun_completion(batch, retry_timeout, batch_size)) + for batch in request_batches + ] + responses = await asyncio.gather(*all_tasks) + # Flatten responses + choices = [] + for res_dict in responses: + choices.extend(res_dict["choices"]) + final_response_dict = self.format_response({"choices": choices}) + return Response( + final_response_dict, + cached=False, + request_params=request_params, + **RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore + ) def get_score_prompt_request( self, request: Request, - ) -> Tuple[Callable[[], Dict], Dict]: + ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. diff --git a/manifest/clients/cohere.py b/manifest/clients/cohere.py index dd7c337..5391fdf 100644 --- a/manifest/clients/cohere.py +++ b/manifest/clients/cohere.py @@ -26,9 +26,9 @@ class CohereClient(Client): "frequency_penalty": ("frequency_penalty", 0.0), "presence_penalty": ("presence_penalty", 0.0), "stop_sequences": ("stop_sequences", None), - "client_timeout": ("client_timeout", 60), # seconds } REQUEST_CLS = LMRequest + NAME = "cohere" def connect( self, diff --git a/manifest/clients/diffuser.py b/manifest/clients/diffuser.py index 3b17279..9d43f62 100644 --- a/manifest/clients/diffuser.py +++ b/manifest/clients/diffuser.py @@ -22,9 +22,9 @@ class DiffuserClient(Client): "n": ("num_images_per_prompt", 1), "guidance_scale": ("guidance_scale", 7.5), "eta": ("eta", 0.0), - "client_timeout": ("client_timeout", 120), # seconds } REQUEST_CLS = DiffusionRequest + NAME = "diffuser" def connect( self, diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py index f47a5a1..e43b519 100644 --- a/manifest/clients/dummy.py +++ b/manifest/clients/dummy.py @@ -1,9 +1,10 @@ """Dummy client.""" import logging -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Dict, Optional from manifest.clients.client import Client from manifest.request import LMRequest, Request +from manifest.response import Response logger = logging.getLogger(__name__) @@ -16,6 +17,7 @@ class DummyClient(Client): "n": ("num_results", 1), } REQUEST_CLS = LMRequest + NAME = "dummy" def connect( self, @@ -67,7 +69,7 @@ class DummyClient(Client): """ return {"engine": "dummy"} - def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]: + def run_request(self, request: Request) -> Response: """ Get request string function. 
@@ -84,19 +86,29 @@ class DummyClient(Client): num_results = 1 request_params = request.to_dict(self.PARAMS) - def _run_completion() -> Dict: - return { - "choices": [{"text": "hello"}] - * int(request_params["num_results"]) - * num_results - } + response_dict = { + "choices": [{"text": "hello"}] + * int(request_params["num_results"]) + * num_results + } + return Response(response_dict, False, request_params) - return _run_completion, request_params + async def arun_batch_request(self, request: Request) -> Response: + """ + Get async request string function. + + Args: + request: request. + + Returns: + response. + """ + return self.run_request(request) def get_score_prompt_request( self, request: Request, - ) -> Tuple[Callable[[], Dict], Dict]: + ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. @@ -113,17 +125,15 @@ class DummyClient(Client): num_results = 1 request_params = {"prompt": request.prompt} - def _run_completion() -> Dict: - return { - "choices": [ - { - "text": request.prompt - if isinstance(request.prompt, str) - else request.prompt[i], - "logprob": 0.3, - } - for i in range(num_results) - ] - } - - return _run_completion, request_params + response_dict = { + "choices": [ + { + "text": request.prompt + if isinstance(request.prompt, str) + else request.prompt[i], + "logprob": 0.3, + } + for i in range(num_results) + ] + } + return Response(response_dict, False, request_params) diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py index d0f889e..f29bc43 100644 --- a/manifest/clients/huggingface.py +++ b/manifest/clients/huggingface.py @@ -1,11 +1,12 @@ """Hugging Face client.""" import logging -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Dict, Optional import requests from manifest.clients.client import Client -from manifest.request import LMRequest, Request +from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request +from manifest.response import Response logger = logging.getLogger(__name__) @@ -22,9 +23,9 @@ class HuggingFaceClient(Client): "top_k": ("top_k", 50), "repetition_penalty": ("repetition_penalty", 1.0), "do_sample": ("do_sample", True), - "client_timeout": ("client_timeout", 120), # seconds } REQUEST_CLS = LMRequest + NAME = "huggingface" def connect( self, @@ -81,7 +82,7 @@ class HuggingFaceClient(Client): def get_score_prompt_request( self, request: Request, - ) -> Tuple[Callable[[], Dict], Dict]: + ) -> Response: """ Get the logit score of the prompt via a forward pass of the model. @@ -92,26 +93,26 @@ class HuggingFaceClient(Client): request function that takes no input. request parameters as dict. """ - request_params = request.to_dict(self.PARAMS) + request_params = self.get_request_params(request) retry_timeout = request_params.pop("client_timeout") + for key in DEFAULT_REQUEST_KEYS: + request_params.pop(key, None) # Do not add params like we do with request as the model isn't sampling request_params = {"prompt": request.prompt} - def _run_completion() -> Dict: - post_str = self.host + "/score_sequence" - try: - res = requests.post( - post_str, - json=request_params, - timeout=retry_timeout, - ) - res.raise_for_status() - except requests.Timeout as e: - logger.error("HF request timed out. 
Increase client_timeout.") - raise e - except requests.exceptions.HTTPError as e: - logger.error(res.text) - raise e - return res.json() - - return _run_completion, request_params + post_str = self.host + "/score_sequence" + try: + res = requests.post( + post_str, + json=request_params, + timeout=retry_timeout, + ) + res.raise_for_status() + except requests.Timeout as e: + logger.error("HF request timed out. Increase client_timeout.") + raise e + except requests.exceptions.HTTPError as e: + logger.error(res.text) + raise e + response_dict = res.json() + return Response(response_dict, cached=False, request_params=request_params) diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index 5cbe688..fce20dc 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -38,9 +38,9 @@ class OpenAIClient(Client): "stop_sequences": ("stop", None), # OpenAI doesn't like empty lists "presence_penalty": ("presence_penalty", 0.0), "frequency_penalty": ("frequency_penalty", 0.0), - "client_timeout": ("client_timeout", 60), # seconds } REQUEST_CLS = LMRequest + NAME = "openai" def connect( self, diff --git a/manifest/clients/openaichat.py b/manifest/clients/openaichat.py new file mode 100644 index 0000000..2a8b690 --- /dev/null +++ b/manifest/clients/openaichat.py @@ -0,0 +1,171 @@ +"""OpenAIChat client.""" +import copy +import logging +import os +from typing import Any, Dict, Optional + +from manifest.clients.client import Client +from manifest.request import LMRequest + +logger = logging.getLogger(__name__) + +OPENAICHAT_ENGINES = { + "gpt-3.5-turbo", +} + + +class OpenAIChatClient(Client): + """OpenAI Chat client.""" + + # User param -> (client param, default value) + PARAMS = { + "engine": ("model", "gpt-3.5-turbo"), + "temperature": ("temperature", 1.0), + "max_tokens": ("max_tokens", 10), + "n": ("n", 1), + "top_p": ("top_p", 1.0), + "stop_sequences": ("stop", None), # OpenAI doesn't like empty lists + "presence_penalty": ("presence_penalty", 0.0), + "frequency_penalty": ("frequency_penalty", 0.0), + } + REQUEST_CLS = LMRequest + NAME = "openaichat" + + def connect( + self, + connection_str: Optional[str] = None, + client_args: Dict[str, Any] = {}, + ) -> None: + """ + Connect to the OpenAI server. + + connection_str is passed as default OPENAI_API_KEY if variable not set. + + Args: + connection_str: connection string. + client_args: client arguments. + """ + self.api_key = os.environ.get("OPENAI_API_KEY", connection_str) + if self.api_key is None: + raise ValueError( + "OpenAI API key not set. Set OPENAI_API_KEY environment " + "variable or pass through `client_connection`." + ) + self.host = "https://api.openai.com/v1" + for key in self.PARAMS: + setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) + if getattr(self, "engine") not in OPENAICHAT_ENGINES: + raise ValueError( + f"Invalid engine {getattr(self, 'engine')}. " + f"Must be {OPENAICHAT_ENGINES}." + ) + + def close(self) -> None: + """Close the client.""" + pass + + def get_generation_url(self) -> str: + """Get generation URL.""" + return self.host + "/chat/completions" + + def get_generation_header(self) -> Dict[str, str]: + """ + Get generation header. + + Returns: + header. + """ + return {"Authorization": f"Bearer {self.api_key}"} + + def supports_batch_inference(self) -> bool: + """Return whether the client supports batch inference.""" + return False + + def get_model_params(self) -> Dict: + """ + Get model params. 
+ + By getting model params from the server, we can add to request + and make sure cache keys are unique to model. + + Returns: + model params. + """ + return {"model_name": "openaichat", "engine": getattr(self, "engine")} + + def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict: + """Format request params for chat. + + Args: + request_params: request params. + + Returns: + formatted request params. + """ + # Format for chat model + request_params = copy.deepcopy(request_params) + prompt = request_params.pop("prompt") + if isinstance(prompt, str): + prompt_list = [prompt] + else: + prompt_list = prompt + messages = [{"role": "user", "content": prompt} for prompt in prompt_list] + request_params["messages"] = messages + return request_params + + def _format_request_for_text(self, response_dict: Dict[str, Any]) -> Dict: + """Format response for text. + + Args: + response_dict: response. + + Return: + formatted response. + """ + new_choices = [] + response_dict = copy.deepcopy(response_dict) + for message in response_dict["choices"]: + new_choices.append({"text": message["message"]["content"]}) + response_dict["choices"] = new_choices + return response_dict + + def _run_completion( + self, request_params: Dict[str, Any], retry_timeout: int + ) -> Dict: + """Execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + + Returns: + response as dict. + """ + # Format for chat model + request_params = self._format_request_for_chat(request_params) + response_dict = super()._run_completion(request_params, retry_timeout) + # Reformat for text model + response_dict = self._format_request_for_text(response_dict) + return response_dict + + async def _arun_completion( + self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int + ) -> Dict: + """Async execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + batch_size: batch size for requests. + + Returns: + response as dict. 
+ """ + # Format for chat model + request_params = self._format_request_for_chat(request_params) + response_dict = await super()._arun_completion( + request_params, retry_timeout, batch_size + ) + # Reformat for text model + response_dict = self._format_request_for_text(response_dict) + return response_dict diff --git a/manifest/clients/toma.py b/manifest/clients/toma.py index af7cd4d..28615e4 100644 --- a/manifest/clients/toma.py +++ b/manifest/clients/toma.py @@ -31,9 +31,9 @@ class TOMAClient(Client): "top_p": ("top_p", 0.9), "top_k": ("top_k", 40), "stop_sequences": ("stop", []), - "client_timeout": ("client_timeout", 120), # seconds } REQUEST_CLS = LMRequest + NAME = "toma" def connect( self, diff --git a/manifest/clients/toma_diffuser.py b/manifest/clients/toma_diffuser.py index d5bf7a8..15721a5 100644 --- a/manifest/clients/toma_diffuser.py +++ b/manifest/clients/toma_diffuser.py @@ -30,9 +30,9 @@ class TOMADiffuserClient(TOMAClient): "width": ("width", 512), "n": ("n", 1), "guidance_scale": ("guidance_scale", 7.5), - "client_timeout": ("client_timeout", 120), # seconds } REQUEST_CLS = DiffusionRequest # type: ignore + NAME = "tomadiffuser" def get_model_params(self) -> Dict: """ diff --git a/manifest/manifest.py b/manifest/manifest.py index 3793b7b..ac3c2e0 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -1,4 +1,5 @@ """Manifest class.""" +import copy import logging from typing import Any, Dict, List, Optional, Tuple, Union, cast @@ -13,40 +14,32 @@ from manifest.clients.cohere import CohereClient from manifest.clients.dummy import DummyClient from manifest.clients.huggingface import HuggingFaceClient from manifest.clients.openai import OpenAIClient +from manifest.clients.openaichat import OpenAIChatClient from manifest.clients.toma import TOMAClient from manifest.request import Request from manifest.response import Response -from manifest.session import Session logging.getLogger("openai").setLevel(logging.WARNING) logger = logging.getLogger(__name__) CLIENT_CONSTRUCTORS = { - "openai": OpenAIClient, - "cohere": CohereClient, - "ai21": AI21Client, - "huggingface": HuggingFaceClient, - "dummy": DummyClient, - "toma": TOMAClient, + OpenAIClient.NAME: OpenAIClient, + OpenAIChatClient.NAME: OpenAIChatClient, + CohereClient.NAME: CohereClient, + AI21Client.NAME: AI21Client, + HuggingFaceClient.NAME: HuggingFaceClient, + DummyClient.NAME: DummyClient, + TOMAClient.NAME: TOMAClient, } -# ChatGPT -try: - from manifest.clients.chatgpt import ChatGPTClient - - CLIENT_CONSTRUCTORS["chatgpt"] = ChatGPTClient -except Exception: - logger.info("ChatGPT not installed. Skipping import.") - pass - # Diffusion DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"] try: from manifest.clients.diffuser import DiffuserClient from manifest.clients.toma_diffuser import TOMADiffuserClient - CLIENT_CONSTRUCTORS["diffuser"] = DiffuserClient - CLIENT_CONSTRUCTORS["tomadiffuser"] = TOMADiffuserClient + CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient + CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient except Exception: logger.info("Diffusion not supported. Skipping import.") pass @@ -69,7 +62,6 @@ class Manifest: client_connection: Optional[str] = None, cache_name: str = "noop", cache_connection: Optional[str] = None, - session_id: Optional[str] = None, stop_token: str = "", **kwargs: Any, ): @@ -81,9 +73,6 @@ class Manifest: client_connection: connection string for client. cache_name: name of cache. cache_connection: connection string for cache. 
- session_id: session id for user session cache. - None (default) means no session logging. - "_default" means generate new session id. stop_token: stop token prompt generation. Can be overridden in run @@ -114,17 +103,6 @@ self.client = CLIENT_CONSTRUCTORS[self.client_name]( # type: ignore client_connection, client_args=kwargs ) - if session_id is not None: - if self.client_name == "diffuser": - raise NotImplementedError( - "Session logging not implemented for Diffuser client." - ) - if session_id == "_default": - # Set session_id to None for Session random id - session_id = None - self.session = Session(session_id) - else: - self.session = None if len(kwargs) > 0: raise ValueError(f"{list(kwargs.items())} arguments are not recognized.") @@ -195,11 +173,133 @@ logger.warning(f"{list(request_unused_kwargs)} arguments are unused.") return + def _split_cached_requests( + self, + request: Request, + overwrite_cache: bool, + ) -> Tuple[Dict[int, Response], Request]: + """Split a request into cached responses and Requests to run. + + Args: + request: request object. + overwrite_cache: whether to overwrite cache. + + Returns: + cached_idx_to_response: dict of cached responses. + new_request: request object with only prompts to run. + """ + cached_idx_to_response: Dict[int, Response] = {} + new_request = copy.deepcopy(request) + if not overwrite_cache: + if isinstance(new_request.prompt, list): + new_request.prompt = [] + for idx, prompt_str in enumerate(request.prompt): + single_request = copy.deepcopy(request) + single_request.prompt = prompt_str + possible_response = self.cache.get( + self.client.get_cache_key(single_request) + ) + if possible_response: + cached_idx_to_response[idx] = possible_response + else: + new_request.prompt.append(prompt_str) + else: + possible_response = self.cache.get( + self.client.get_cache_key(new_request) + ) + if possible_response: + cached_idx_to_response[0] = possible_response + new_request.prompt = None + return cached_idx_to_response, new_request + + def _stitch_responses_and_cache( + self, + request: Request, + response: Union[Response, None], + cached_idx_to_response: Dict[int, Response], + ) -> Response: + """Stitch together the cached and uncached responses.""" + # We stitch the responses (the choices) here from both the new request and + # the cached entries. 
+ all_model_choices = [] + all_input_prompts = [] + response_idx = 0 + number_prompts = len(cached_idx_to_response) + single_output = False + if response: + if isinstance(response.get_request()["prompt"], str): + single_output = True + number_prompts += 1 + else: + number_prompts += len(response.get_request()["prompt"]) + response_gen_key = None + response_logits_key = None + response_item_key = None + for idx in range(number_prompts): + if idx in cached_idx_to_response: + cached_res = cached_idx_to_response[idx] + response_gen_key = cached_res.generation_key + response_logits_key = cached_res.logits_key + response_item_key = cached_res.item_key + all_input_prompts.append(cached_res.get_request()["prompt"]) + if request.n == 1: + assert ( + len(cached_res.get_json_response()[response_gen_key]) == 1 + ), "cached response should have only one choice" + all_model_choices.append( + cached_res.get_json_response()[response_gen_key][0] + ) + else: + all_model_choices.extend( + cached_res.get_json_response()[response_gen_key] + ) + else: + assert response is not None, "response should not be None" + response = cast(Response, response) + response_gen_key = response.generation_key + response_logits_key = response.logits_key + response_item_key = response.item_key + # the choices list in the response is a flat one. + # length is request.n * num_prompts + current_choices = response.get_json_response()[response_gen_key][ + response_idx * request.n : (response_idx + 1) * request.n + ] + all_model_choices.extend(current_choices) + + if isinstance(response.get_request()["prompt"], list): + prompt = response.get_request()["prompt"][response_idx] + else: + prompt = str(response.get_request()["prompt"]) + all_input_prompts.append(prompt) + # set cache + new_request = copy.deepcopy(request) + new_request.prompt = prompt + cache_key = self.client.get_cache_key(new_request) + new_response_key = copy.deepcopy(response.get_json_response()) + new_response_key[response_gen_key] = current_choices + self.cache.set(cache_key, new_response_key) + response_idx += 1 + + new_request = copy.deepcopy(request) + new_request.prompt = ( + all_input_prompts + if len(all_input_prompts) > 1 or not single_output + else all_input_prompts[0] + ) + response_obj = Response( + {response_gen_key: all_model_choices}, + cached=len(cached_idx_to_response) > 0, + request_params=self.client.get_cache_key(new_request), + generation_key=response_gen_key, + logits_key=response_logits_key, + item_key=response_item_key, + ) + return response_obj + def run( self, prompt: Union[str, List[str]], overwrite_cache: bool = False, - run_id: Optional[str] = None, stop_token: Optional[str] = None, return_response: bool = False, **kwargs: Any, @@ -210,7 +310,6 @@ class Manifest: Args: prompt: prompt(s) to run. overwrite_cache: whether to overwrite cache. - run_id: run id for cache to repeat same run. stop_token: stop token for prompt generation. Default is self.stop_token. "" for no stop token. 
@@ -223,28 +322,88 @@ class Manifest: stop_token = stop_token if stop_token is not None else self.stop_token # Must pass kwargs as dict for client "pop" methods removed used arguments - request_params = self.client.get_request_params(prompt, kwargs) + request_params = self.client.get_request(prompt, kwargs) # Avoid nested list of results - enforce n = 1 for batch if is_batch and request_params.n > 1: raise ValueError("Batch mode does not support n > 1.") - possible_request, full_kwargs = self.client.get_request(request_params) + self._validate_kwargs(kwargs, request_params) + + cached_idx_to_response, request_params = self._split_cached_requests( + request_params, overwrite_cache + ) + # If not None value or empty list - run new request + if request_params.prompt: + response = self.client.run_request(request_params) + else: + # Nothing to run + response = None + + final_response = self._stitch_responses_and_cache( + request=request_params, + response=response, + cached_idx_to_response=cached_idx_to_response, + ) + + # Extract text results + if return_response: + return final_response + else: + return final_response.get_response(stop_token, is_batch) + + async def arun_batch( + self, + prompts: List[str], + overwrite_cache: bool = False, + stop_token: Optional[str] = None, + return_response: bool = False, + **kwargs: Any, + ) -> Union[List[str], List[np.ndarray], Response]: + """ + Run a batch of prompts with async. + + Args: + prompts: prompts to run. + overwrite_cache: whether to overwrite cache. + stop_token: stop token for prompt generation. + Default is self.stop_token. + "" for no stop token. + return_response: whether to return Response object. + Returns: + response from prompt. + """ + stop_token = stop_token if stop_token is not None else self.stop_token + # Must pass kwargs as dict for client "pop" methods removed used arguments + request_params = self.client.get_request(prompts, kwargs) + # Avoid nested list of results - enforce n = 1 for batch + if request_params.n > 1: + raise ValueError("Batch mode does not support n > 1.") self._validate_kwargs(kwargs, request_params) - # Create cacke key - cache_key = full_kwargs.copy() - # Make query model dependent - cache_key.update(self.client.get_model_params()) - if run_id: - cache_key["run_id"] = run_id - response_obj = self.cache.get(cache_key, overwrite_cache, possible_request) - # Log session dictionary values - if self.session: - self.session.log_query(cache_key, response_obj.to_dict()) + + cached_idx_to_response, request_params = self._split_cached_requests( + request_params, overwrite_cache + ) + # If not None value or empty list - run new request + if request_params.prompt: + response = await self.client.arun_batch_request(request_params) + else: + # Nothing to run + response = None + + final_response = self._stitch_responses_and_cache( + request=request_params, + response=response, + cached_idx_to_response=cached_idx_to_response, + ) + # Extract text results if return_response: - return response_obj + return final_response else: - return response_obj.get_response(stop_token, is_batch) + return cast( + Union[List[str], List[np.ndarray]], + final_response.get_response(stop_token, True), + ) def score_prompt( self, @@ -257,8 +416,6 @@ class Manifest: Returns the response object with logits of the prompt. - Prompt scoring is not part of a session cache. - Args: prompt: prompt(s) to run. overwrite_cache: whether to overwrite cache. @@ -267,66 +424,31 @@ class Manifest: response from prompt. 
""" # Must pass kwargs as dict for client "pop" methods removed used arguments - request_params = self.client.get_request_params(prompt, kwargs) + request_params = self.client.get_request(prompt, kwargs) + request_params.request_type = "score_prompt" if request_params.n > 1: raise ValueError("Sequence scoring does not support n > 1.") - - try: - possible_request, full_kwargs = cast( - HuggingFaceClient, self.client - ).get_score_prompt_request(request_params) - except AttributeError: - raise ValueError("`score_prompt` only supported for HF models.") - self._validate_kwargs(kwargs, request_params) - # Create cacke key - cache_key = full_kwargs.copy() - # Make query model dependent - cache_key.update(self.client.get_model_params()) - response_obj = self.cache.get(cache_key, overwrite_cache, possible_request) - return response_obj.to_dict() - - def get_last_queries( - self, - last_n: int = -1, - return_raw_values: bool = False, - stop_token: Optional[str] = None, - ) -> List[Tuple[Any, Any]]: - """ - Get last n queries from current session. - - If last_n is -1, return all queries. By default will only return the - prompt text and result text unles return_raw_values is False. - Args: - last_n: last n queries. - return_raw_values: whether to return raw values as dicts. - stop_token: stop token for prompt results to be applied to all results. + cached_idx_to_response, request_params = self._split_cached_requests( + request_params, overwrite_cache + ) + # If not None value or empty list - run new request + if request_params.prompt: + try: + response = cast( + HuggingFaceClient, self.client + ).get_score_prompt_request(request_params) + except AttributeError: + raise ValueError("`score_prompt` only supported for HF models.") + else: + # Nothing to run + response = None - Returns: - last n list of queries and outputs. - """ - if self.session is None: - raise ValueError( - "Session was not initialized. Set `session_id` when loading Manifest." - ) - stop_token = stop_token if stop_token is not None else self.stop_token - last_queries = self.session.get_last_queries(last_n) - if not return_raw_values: - last_queries = [ - ( - query["prompt"], - Response.from_dict(response).get_response( - stop_token, is_batch=isinstance(query["prompt"], list) - ), - ) # type: ignore - for query, response in last_queries - ] - return last_queries - - def open_explorer(self) -> None: - """Open the explorer for jupyter widget.""" - # Open explorer - # TODO: implement - pass + final_response = self._stitch_responses_and_cache( + request=request_params, + response=response, + cached_idx_to_response=cached_idx_to_response, + ) + return final_response.to_dict() diff --git a/manifest/request.py b/manifest/request.py index a25732b..27dc57b 100644 --- a/manifest/request.py +++ b/manifest/request.py @@ -3,6 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel +NOT_CACHE_KEYS = {"client_timeout", "batch_size"} +DEFAULT_REQUEST_KEYS = { + "client_timeout": ("client_timeout", 60), # seconds + "batch_size": ("batch_size", 1), + "run_id": ("run_id", None), + "request_type": ("request_type", None), +} + class Request(BaseModel): """Request object.""" @@ -17,7 +25,16 @@ class Request(BaseModel): n: int = 1 # Timeout - client_timeout: int = 60 + client_timeout: int = 120 + + # Run id used to repeat run with same parameters + run_id: Optional[str] = None + + # Batch size for async batch run + batch_size: int = 8 + + # Request type None is for completion. 
Used to scoring prompt + request_type: str = None def to_dict( self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True @@ -82,6 +99,9 @@ class LMRequest(Request): class DiffusionRequest(Request): """Diffusion Model Request object.""" + # Request type + request_type: str = "diffusion" + # Number of steps num_inference_steps: int = 50 diff --git a/manifest/response.py b/manifest/response.py index 0525a42..f747e61 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -4,6 +4,17 @@ from typing import Any, Dict, List, Union import numpy as np +RESPONSE_CONSTRUCTORS = { + "diffuser": { + "logits_key": "token_logprobs", + "item_key": "array", + }, + "tomadiffuser": { + "logits_key": "token_logprobs", + "item_key": "array", + }, +} + class NumpyArrayEncoder(json.JSONEncoder): """Numpy array encoder.""" diff --git a/manifest/session.py b/manifest/session.py deleted file mode 100644 index aee505f..0000000 --- a/manifest/session.py +++ /dev/null @@ -1,156 +0,0 @@ -"""User query session logging.""" -import logging -import os -import sqlite3 -import uuid -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -from manifest.caches.serializers import Serializer - -logging.getLogger("sqlitedict").setLevel(logging.WARNING) -logger = logging.getLogger(__name__) - - -class Session: - """A user session for caching requests.""" - - def __init__(self, session_id: Optional[str] = None) -> None: - """ - Initialize session. - - If session_id already exists, will append to existing session. - - Args: - session_id: session id. - - """ - manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home())) - self.db_file = manifest_home / ".manifest" / "session.db" - self.db_file.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_file)) - self.serializer = Serializer() - self._create_table() - if not session_id: - self.session_id = str(uuid.uuid4()) - self.query_id = 0 - else: - self.session_id = session_id - self.query_id = self._get_latest_query_id(self.session_id) - self.query_id += 1 - logger.info(f"Starting session {self.session_id}") - return - - def close(self) -> None: - """Close the client.""" - self.conn.close() - - @classmethod - def get_session_keys(cls, db_file: Path) -> List[str]: - """Get available session keys from cached file.""" - try: - conn = sqlite3.connect(str(db_file)) - query = """SELECT DISTINCT session_id FROM queries""" - cur = conn.cursor() - res = cur.execute(query) - return [x[0] for x in res.fetchall()] - except sqlite3.OperationalError: - logger.info( - "There is no database with the 'queries' table. " - "Are you sure you are using the right session file" - ) - return [] - - def _execute_query(self, query: str, *args: Any) -> Any: - """ - Execute query with optional args. - - Args: - query: query to execute. - """ - cur = self.conn.cursor() - res = cur.execute(query, args) - self.conn.commit() - return res - - def _create_table(self) -> None: - """Create table if not exists.""" - query = """CREATE TABLE IF NOT EXISTS queries ( - query_id integer NOT NULL, - session_id text NOT NULL, - query_key text NOT NULL, - response_key text NOT NULL - );""" - self._execute_query(query) - return - - def _get_latest_query_id(self, session_id: str) -> int: - """ - Get latest query id issued if resuming session. - - If no session_id, return -1. - - Args: - session_id: session id. - - Returns: - latest query id. - """ - query = """SELECT query_id - FROM queries - WHERE session_id = ? 
- ORDER BY query_id DESC LIMIT 1""" - res = self._execute_query(query, session_id).fetchone() - if res: - return res[0] - return -1 - - def log_query( - self, query_key: Dict[str, Any], response_key: Dict[str, Any] - ) -> None: - """ - Log the query and response. - - Args: - query_key: query of user (dump of request params). - response_key: response of server (dump of response). - """ - query = """INSERT INTO queries VALUES (?, ?, ?, ?);""" - self._execute_query( - query, - self.query_id, - self.session_id, - self.serializer.request_to_key(query_key), - self.serializer.response_to_key(response_key), - ) - self.query_id += 1 - return - - def get_last_queries( - self, last_n: int = -1 - ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: - """ - Get last n queries from current session. - - If last_n is -1, return all queries. - - Args: - last_n: last n queries. - - Returns: - last n list of queries and outputs. - """ - first_query = self.query_id - last_n if last_n > 0 else -1 - query = """SELECT query_key, response_key - FROM queries - WHERE session_id = ? AND query_id >= ? - ORDER BY query_id;""" - res = self._execute_query(query, self.session_id, first_query) - parsed_res = [ - ( - self.serializer.key_to_request(pair[0]), - self.serializer.key_to_response(pair[1]), - ) - for pair in res.fetchall() - ] - return parsed_res diff --git a/setup.py b/setup.py index c40e5fa..b69057b 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ REQUIRED = [ "pydantic>=1.9.0", "redis>=4.3.1", "requests>=2.27.1", + "aiohttp>=3.8.0", "sqlitedict>=2.0.0", "xxhash>=3.0.0", ] @@ -50,9 +51,6 @@ EXTRAS = { "fastapi>=0.70.0", "uvicorn>=0.18.0", ], - "chatgpt": [ - "pyChatGPT>=0.4.3", - ], "diffusers": [ "pillow>=9.0.0", ], @@ -60,7 +58,7 @@ EXTRAS = { "pg8000", "cloud-sql-python-connector[pg8000]>=1.0.0", "sqlalchemy", - ], + ], "dev": [ "autopep8>=1.6.0", "black>=22.3.0", diff --git a/tests/conftest.py b/tests/conftest.py index ddc9f68..3687751 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,10 +42,3 @@ def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None engine = sqlalchemy.create_engine(url) monkeypatch.setattr(sqlalchemy, "create_engine", lambda *args, **kwargs: engine) return engine # type: ignore - - -@pytest.fixture -def session_cache(tmpdir: str) -> Generator[Path, None, None]: - """Session cache dir.""" - os.environ["MANIFEST_HOME"] = str(tmpdir) - yield Path(tmpdir) diff --git a/tests/test_cache.py b/tests/test_cache.py index adba0a8..df74929 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -85,26 +85,20 @@ def test_get( cache = cast(Cache, _get_postgres_cache()) test_request = {"test": "hello", "testA": "world"} - compute = lambda: {"choices": [{"text": "hello"}]} + test_response = {"choices": [{"text": "hello"}]} - response = cache.get(test_request, overwrite_cache=False, compute=compute) - assert response.get_response() == "hello" - assert not response.is_cached() - assert response.get_request() == test_request + response = cache.get(test_request) + assert response is None - response = cache.get(test_request, overwrite_cache=False, compute=compute) + cache.set(test_request, test_response) + response = cache.get(test_request) assert response.get_response() == "hello" assert response.is_cached() assert response.get_request() == test_request - response = cache.get(test_request, overwrite_cache=True, compute=compute) - assert response.get_response() == "hello" - assert not response.is_cached() - assert response.get_request() == test_request - arr = 
np.random.rand(4, 4) test_request = {"test": "hello", "testA": "world of images"} - compute_arr = lambda: {"choices": [{"array": arr}]} + compute_arr_response = {"choices": [{"array": arr}]} # Test array if cache_type == "sqlite": @@ -114,9 +108,64 @@ def test_get( elif cache_type == "postgres": cache = _get_postgres_cache(client_name="diffuser") - response = cache.get(test_request, overwrite_cache=False, compute=compute_arr) + response = cache.get(test_request) + assert response is None + + cache.set(test_request, compute_arr_response) + response = cache.get(test_request) assert np.allclose(response.get_response(), arr) - assert not response.is_cached() + assert response.is_cached() + assert response.get_request() == test_request + + +@pytest.mark.usefixtures("sqlite_cache") +@pytest.mark.usefixtures("redis_cache") +@pytest.mark.usefixtures("postgres_cache") +@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"]) +def test_get_batch_prompt( + sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str +) -> None: + """Test cache save prompt.""" + if cache_type == "sqlite": + cache = cast(Cache, SQLiteCache(sqlite_cache)) + elif cache_type == "redis": + cache = cast(Cache, RedisCache(redis_cache)) + elif cache_type == "postgres": + cache = cast(Cache, _get_postgres_cache()) + + test_request = {"test": ["hello", "goodbye"], "testA": "world"} + test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]} + + response = cache.get(test_request) + assert response is None + + cache.set(test_request, test_response) + response = cache.get(test_request) + assert response.get_response() == ["hello", "goodbye"] + assert response.is_cached() + assert response.get_request() == test_request + + # Test arrays + arr = np.random.rand(4, 4) + arr2 = np.random.rand(4, 4) + test_request = {"test": ["hello", "goodbye"], "testA": "world of images"} + compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]} + + if cache_type == "sqlite": + cache = SQLiteCache(sqlite_cache, client_name="diffuser") + elif cache_type == "redis": + cache = RedisCache(redis_cache, client_name="diffuser") + elif cache_type == "postgres": + cache = _get_postgres_cache(client_name="diffuser") + + response = cache.get(test_request) + assert response is None + + cache.set(test_request, compute_arr_response) + response = cache.get(test_request) + assert np.allclose(response.get_response()[0], arr) + assert np.allclose(response.get_response()[1], arr2) + assert response.is_cached() assert response.get_request() == test_request @@ -137,14 +186,11 @@ def test_noop_cache() -> None: # Assert always not cached test_request = {"test": "hello", "testA": "world"} - compute = lambda: {"choices": [{"text": "hello"}]} + test_response = {"choices": [{"text": "hello"}]} - response = cache.get(test_request, overwrite_cache=False, compute=compute) - assert response.get_response() == "hello" - assert not response.is_cached() - assert response.get_request() == test_request + response = cache.get(test_request) + assert response is None - response = cache.get(test_request, overwrite_cache=False, compute=compute) - assert response.get_response() == "hello" - assert not response.is_cached() - assert response.get_request() == test_request + cache.set(test_request, test_response) + response = cache.get(test_request) + assert response is None diff --git a/tests/test_client.py b/tests/test_client.py index a4c11ef..9a1d38f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,17 +27,29 @@ def 
test_get_request() -> None: """Test client get request.""" args = {"n": 3} client = DummyClient(connection_str=None, client_args=args) - request_params = client.get_request_params("hello", {}) - request_func, request_params_return = client.get_request(request_params) - assert request_params_return == {"prompt": "hello", "num_results": 3} - assert request_func() == {"choices": [{"text": "hello"}] * 3} - - request_params = client.get_request_params("hello", {"n": 5}) - request_func, request_params_return = client.get_request(request_params) - assert request_params_return == {"prompt": "hello", "num_results": 5} - assert request_func() == {"choices": [{"text": "hello"}] * 5} - - request_params = client.get_request_params(["hello"] * 5, {"n": 1}) - request_func, request_params_return = client.get_request(request_params) - assert request_params_return == {"prompt": ["hello"] * 5, "num_results": 1} - assert request_func() == {"choices": [{"text": "hello"}] * 5} + request_params = client.get_request("hello", {}) + response = client.run_request(request_params) + assert client.get_cache_key(request_params) == { + "prompt": "hello", + "num_results": 3, + "engine": "dummy", + } + assert response.get_json_response() == {"choices": [{"text": "hello"}] * 3} + + request_params = client.get_request("hello", {"n": 5}) + response = client.run_request(request_params) + assert client.get_cache_key(request_params) == { + "prompt": "hello", + "num_results": 5, + "engine": "dummy", + } + assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5} + + request_params = client.get_request(["hello"] * 5, {"n": 1}) + response = client.run_request(request_params) + assert client.get_cache_key(request_params) == { + "prompt": ["hello"] * 5, + "num_results": 1, + "engine": "dummy", + } + assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5} diff --git a/tests/test_huggingface_api.py b/tests/test_huggingface_api.py index a72b7a4..b23ff5f 100644 --- a/tests/test_huggingface_api.py +++ b/tests/test_huggingface_api.py @@ -141,8 +141,8 @@ def test_gpt_score() -> None: result = model.score_sequence(inputs) assert result is not None assert len(result) == 2 - assert math.isclose(round(result[0][0], 3), -19.935) - assert math.isclose(round(result[1][0], 3), -45.831) + assert math.isclose(round(result[0][0], 3), -46.71) + assert math.isclose(round(result[1][0], 3), -12.752) assert isinstance(result[0][1], list) assert isinstance(result[1][1], list) diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 020c298..fc4a5e3 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -1,19 +1,25 @@ """Manifest test.""" -import json +import asyncio from typing import cast import pytest +import requests from manifest import Manifest, Response from manifest.caches.noop import NoopCache from manifest.caches.sqlite import SQLiteCache from manifest.clients.dummy import DummyClient -from manifest.session import Session + +URL = "http://localhost:6000" +try: + _ = requests.post(URL + "/params").json() + MODEL_ALIVE = True +except Exception: + MODEL_ALIVE = False @pytest.mark.usefixtures("sqlite_cache") -@pytest.mark.usefixtures("session_cache") -def test_init(sqlite_cache: str, session_cache: str) -> None: +def test_init(sqlite_cache: str) -> None: """Test manifest initialization.""" with pytest.raises(ValueError) as exc_info: Manifest( @@ -32,7 +38,6 @@ def test_init(sqlite_cache: str, session_cache: str) -> None: assert manifest.client_name == "dummy" assert 
isinstance(manifest.client, DummyClient) assert isinstance(manifest.cache, SQLiteCache) - assert manifest.session is None assert manifest.client.n == 1 # type: ignore assert manifest.stop_token == "" @@ -41,19 +46,16 @@ def test_init(sqlite_cache: str, session_cache: str) -> None: cache_name="noop", n=3, stop_token="\n", - session_id="_default", ) assert manifest.client_name == "dummy" assert isinstance(manifest.client, DummyClient) assert isinstance(manifest.cache, NoopCache) - assert isinstance(manifest.session, Session) assert manifest.client.n == 3 # type: ignore assert manifest.stop_token == "\n" @pytest.mark.usefixtures("sqlite_cache") -@pytest.mark.usefixtures("session_cache") -def test_change_manifest(sqlite_cache: str, session_cache: str) -> None: +def test_change_manifest(sqlite_cache: str) -> None: """Test manifest change.""" manifest = Manifest( client_name="dummy", @@ -65,7 +67,6 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None: assert manifest.client_name == "dummy" assert isinstance(manifest.client, DummyClient) assert isinstance(manifest.cache, SQLiteCache) - assert manifest.session is None assert manifest.client.n == 1 # type: ignore assert manifest.stop_token == "" @@ -73,18 +74,14 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None: assert manifest.client_name == "dummy" assert isinstance(manifest.client, DummyClient) assert isinstance(manifest.cache, SQLiteCache) - assert manifest.session is None assert manifest.client.n == 1 # type: ignore assert manifest.stop_token == "\n" @pytest.mark.usefixtures("sqlite_cache") -@pytest.mark.usefixtures("session_cache") @pytest.mark.parametrize("n", [1, 2]) @pytest.mark.parametrize("return_response", [True, False]) -def test_run( - sqlite_cache: str, session_cache: str, n: int, return_response: bool -) -> None: +def test_run(sqlite_cache: str, n: int, return_response: bool) -> None: """Test manifest run.""" manifest = Manifest( client_name="dummy", @@ -111,15 +108,12 @@ def test_run( else: res = cast(str, result) assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": "This is a prompt", - "engine": "dummy", - "num_results": n, - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": n, + }, ) is not None ) @@ -136,16 +130,13 @@ def test_run( else: res = cast(str, result) assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": "This is a prompt", - "engine": "dummy", - "num_results": n, - "run_id": "34", - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": n, + "run_id": "34", + } ) is not None ) @@ -162,15 +153,12 @@ def test_run( else: res = cast(str, result) assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": "Hello is a prompt", - "engine": "dummy", - "num_results": n, - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "Hello is a prompt", + "engine": "dummy", + "num_results": n, + }, ) is not None ) @@ -187,15 +175,12 @@ def test_run( else: res = cast(str, result) assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": "Hello is a prompt", - "engine": "dummy", - "num_results": n, - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "Hello is a prompt", + "engine": "dummy", + "num_results": n, + }, ) is not None ) @@ -206,12 +191,9 @@ def test_run( @pytest.mark.usefixtures("sqlite_cache") -@pytest.mark.usefixtures("session_cache") @pytest.mark.parametrize("n", [1, 2]) 
@pytest.mark.parametrize("return_response", [True, False]) -def test_batch_run( - sqlite_cache: str, session_cache: str, n: int, return_response: bool -) -> None: +def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None: """Test manifest run.""" manifest = Manifest( client_name="dummy", @@ -233,6 +215,16 @@ def test_batch_run( else: res = cast(str, result) assert res == ["hello"] + assert ( + manifest.cache.get( + { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": n, + }, + ) + is not None + ) prompt = ["Hello is a prompt", "Hello is a prompt"] result = manifest.run(prompt, return_response=return_response) @@ -243,6 +235,42 @@ def test_batch_run( else: res = cast(str, result) assert res == ["hello", "hello"] + assert ( + manifest.cache.get( + { + "prompt": "Hello is a prompt", + "engine": "dummy", + "num_results": n, + }, + ) + is not None + ) + + result = manifest.run(prompt, return_response=True) + res = cast(Response, result).get_response(manifest.stop_token, is_batch=True) + assert cast(Response, result).is_cached() + + assert ( + manifest.cache.get( + { + "prompt": "New prompt", + "engine": "dummy", + "num_results": n, + }, + ) + is None + ) + prompt = ["This is a prompt", "New prompt"] + result = manifest.run(prompt, return_response=return_response) + if return_response: + res = cast(Response, result).get_response( + manifest.stop_token, is_batch=True + ) + # Cached because one item is in cache + assert cast(Response, result).is_cached() + else: + res = cast(str, result) + assert res == ["hello", "hello"] prompt = ["Hello is a prompt", "Hello is a prompt"] result = manifest.run(prompt, stop_token="ll", return_response=return_response) @@ -253,6 +281,72 @@ def test_batch_run( assert res == ["he", "he"] +@pytest.mark.usefixtures("sqlite_cache") +def test_abatch_run(sqlite_cache: str) -> None: + """Test manifest run.""" + manifest = Manifest( + client_name="dummy", + cache_name="sqlite", + cache_connection=sqlite_cache, + ) + prompt = ["This is a prompt"] + result = asyncio.run(manifest.arun_batch(prompt, return_response=True)) + + res = cast(Response, result).get_response(manifest.stop_token, is_batch=True) + assert res == ["hello"] + assert ( + manifest.cache.get( + { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": 1, + }, + ) + is not None + ) + + prompt = ["Hello is a prompt", "Hello is a prompt"] + result = asyncio.run(manifest.arun_batch(prompt, return_response=True)) + res = cast(Response, result).get_response(manifest.stop_token, is_batch=True) + assert res == ["hello", "hello"] + assert ( + manifest.cache.get( + { + "prompt": "Hello is a prompt", + "engine": "dummy", + "num_results": 1, + }, + ) + is not None + ) + + result = asyncio.run(manifest.arun_batch(prompt, return_response=True)) + res = cast(Response, result).get_response(manifest.stop_token, is_batch=True) + assert cast(Response, result).is_cached() + + assert ( + manifest.cache.get( + { + "prompt": "New prompt", + "engine": "dummy", + "num_results": 1, + }, + ) + is None + ) + prompt = ["This is a prompt", "New prompt"] + result = asyncio.run(manifest.arun_batch(prompt, return_response=True)) + res = cast(Response, result).get_response(manifest.stop_token, is_batch=True) + # Cached because one item is in cache + assert cast(Response, result).is_cached() + assert res == ["hello", "hello"] + + prompt = ["Hello is a prompt", "Hello is a prompt"] + result = asyncio.run(manifest.arun_batch(prompt, return_response=True)) + res = cast(Response, 
result).get_response(stop_token="ll", is_batch=True) + assert res == ["he", "he"] + + @pytest.mark.usefixtures("sqlite_cache") def test_score_run(sqlite_cache: str) -> None: """Test manifest run.""" @@ -264,16 +358,14 @@ def test_score_run(sqlite_cache: str) -> None: prompt = "This is a prompt" result = manifest.score_prompt(prompt) - assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": "This is a prompt", - "engine": "dummy", - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": 1, + "request_type": "score_prompt", + }, ) is not None ) @@ -284,20 +376,35 @@ def test_score_run(sqlite_cache: str) -> None: "item_dtype": None, "response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]}, "cached": False, - "request_params": {"prompt": "This is a prompt", "engine": "dummy"}, + "request_params": { + "prompt": "This is a prompt", + "engine": "dummy", + "num_results": 1, + "request_type": "score_prompt", + }, } prompt_list = ["Hello is a prompt", "Hello is another prompt"] result = manifest.score_prompt(prompt_list) assert ( - manifest.cache.get_key( - json.dumps( - { - "prompt": ["Hello is a prompt", "Hello is another prompt"], - "engine": "dummy", - }, - sort_keys=True, - ) + manifest.cache.get( + { + "prompt": "Hello is a prompt", + "engine": "dummy", + "num_results": 1, + "request_type": "score_prompt", + }, + ) + is not None + ) + assert ( + manifest.cache.get( + { + "prompt": "Hello is another prompt", + "engine": "dummy", + "num_results": 1, + "request_type": "score_prompt", + }, ) is not None ) @@ -316,76 +423,64 @@ def test_score_run(sqlite_cache: str) -> None: "request_params": { "prompt": ["Hello is a prompt", "Hello is another prompt"], "engine": "dummy", + "num_results": 1, + "request_type": "score_prompt", }, } -@pytest.mark.usefixtures("session_cache") -def test_log_query(session_cache: str) -> None: - """Test manifest session logging.""" - manifest = Manifest(client_name="dummy", cache_name="noop", session_id="_default") - prompt = "This is a prompt" - _ = manifest.run(prompt, return_response=False) - query_key = { - "prompt": "This is a prompt", - "engine": "dummy", - "num_results": 1, - } - response_key = { - "cached": False, - "request_params": query_key, - "response": {"choices": [{"text": "hello"}]}, - "generation_key": "choices", - "item_dtype": None, - "item_key": "text", - "logits_key": "token_logprobs", - } - assert manifest.get_last_queries(1) == [("This is a prompt", "hello")] - assert manifest.get_last_queries(1, return_raw_values=True) == [ - (query_key, response_key) - ] - assert manifest.get_last_queries(3, return_raw_values=True) == [ - (query_key, response_key) - ] - prior_cache_item = (query_key, response_key) - - prompt_lst = ["This is a prompt", "This is a prompt2"] - _ = manifest.run(prompt_lst, return_response=False) - query_key = { - "prompt": ["This is a prompt", "This is a prompt2"], - "engine": "dummy", - "num_results": 1, - } - response_key = { - "cached": False, - "generation_key": "choices", - "item_dtype": None, - "item_key": "text", - "logits_key": "token_logprobs", - "request_params": query_key, - "response": {"choices": [{"text": "hello"}, {"text": "hello"}]}, - } - assert manifest.get_last_queries(1) == [ - (["This is a prompt", "This is a prompt2"], ["hello", "hello"]) - ] - assert manifest.get_last_queries(1, return_raw_values=True) == [ - (query_key, response_key) - ] - assert manifest.get_last_queries(3, return_raw_values=True) == [ - prior_cache_item, 
- (query_key, response_key), - ] - - # Test no session - manifest = Manifest( - client_name="dummy", - cache_name="noop", +@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}") +@pytest.mark.usefixtures("sqlite_cache") +def test_local_huggingface(sqlite_cache: str) -> None: + """Test local huggingface client.""" + client = Manifest( + client_name="huggingface", + client_connection=URL, + cache_name="sqlite", + cache_connection=sqlite_cache, ) - prompt = "This is a prompt" - _ = manifest.run(prompt, return_response=False) - with pytest.raises(ValueError) as exc_info: - manifest.get_last_queries(1) - assert ( - str(exc_info.value) - == "Session was not initialized. Set `session_id` when loading Manifest." + + res = client.run("Why are there apples?") + assert isinstance(res, str) and len(res) > 0 + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert isinstance(response.get_response(), str) and len(response.get_response()) > 0 + assert response.is_cached() is True + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert response.is_cached() is True + + res_list = client.run(["Why are there apples?", "Why are there bananas?"]) + assert isinstance(res_list, list) and len(res_list) == 2 + + response = cast( + Response, client.run("Why are there bananas?", return_response=True) + ) + assert response.is_cached() is True + + res_list = asyncio.run( + client.arun_batch(["Why are there pears?", "Why are there oranges?"]) + ) + assert isinstance(res_list, list) and len(res_list) == 2 + + response = cast( + Response, client.run("Why are there oranges?", return_response=True) + ) + assert response.is_cached() is True + + scores = client.score_prompt("Why are there apples?") + assert isinstance(scores, dict) and len(scores) > 0 + assert scores["cached"] is False + assert len(scores["response"]["choices"][0]["token_logprobs"]) == len( + scores["response"]["choices"][0]["tokens"] + ) + + scores = client.score_prompt(["Why are there apples?", "Why are there bananas?"]) + assert isinstance(scores, dict) and len(scores) > 0 + assert scores["cached"] is True + assert len(scores["response"]["choices"][0]["token_logprobs"]) == len( + scores["response"]["choices"][0]["tokens"] + ) + assert len(scores["response"]["choices"][0]["token_logprobs"]) == len( + scores["response"]["choices"][0]["tokens"] ) diff --git a/tests/test_request.py b/tests/test_request.py index 721d0cd..0b3ba22 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -10,10 +10,10 @@ def test_llm_init() -> None: request = LMRequest(temperature=0.5) assert request.temperature == 0.5 - request = LMRequest(**{"temperature": 0.5}) + request = LMRequest(**{"temperature": 0.5}) # type: ignore assert request.temperature == 0.5 - request = LMRequest(**{"temperature": 0.5, "prompt": "test"}) + request = LMRequest(**{"temperature": 0.5, "prompt": "test"}) # type: ignore assert request.temperature == 0.5 assert request.prompt == "test" @@ -26,10 +26,10 @@ def test_diff_init() -> None: request = DiffusionRequest(height=128) assert request.height == 128 - request = DiffusionRequest(**{"height": 128}) + request = DiffusionRequest(**{"height": 128}) # type: ignore assert request.height == 128 - request = DiffusionRequest(**{"height": 128, "prompt": "test"}) + request = DiffusionRequest(**{"height": 128, "prompt": "test"}) # type: ignore assert request.height == 128 assert request.prompt == "test" diff --git a/tests/test_serializer.py 
b/tests/test_serializer.py index b7825cf..8f4e4a0 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,13 +1,12 @@ """Cache test.""" import json -from pathlib import Path import numpy as np from manifest.caches.serializers import ArraySerializer -def test_response_to_key(session_cache: Path) -> None: +def test_response_to_key() -> None: """Test array serializer initialization.""" serializer = ArraySerializer() arr = np.random.rand(4, 4) diff --git a/tests/test_session.py b/tests/test_session.py deleted file mode 100644 index 2dbe52d..0000000 --- a/tests/test_session.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Test session.""" -import sqlite3 -from pathlib import Path - -import pytest - -from manifest.session import Session - - -@pytest.mark.usefixtures("session_cache") -def test_init(session_cache: Path) -> None: - """Test session initialization.""" - session = Session() - assert isinstance(session.conn, sqlite3.Connection) - assert session.db_file == session_cache / ".manifest" / "session.db" - assert session.query_id == 0 - assert (session_cache / ".manifest" / "session.db").exists() - # Remove session cache file. - (session_cache / ".manifest" / "session.db").unlink() - - session = Session("dog_days") - assert isinstance(session.conn, sqlite3.Connection) - assert session.db_file == session_cache / ".manifest" / "session.db" - assert session.query_id == 0 - assert session.session_id == "dog_days" - assert (session_cache / ".manifest" / "session.db").exists() - session.close() - - -@pytest.mark.usefixtures("session_cache") -def test_log_query(session_cache: Path) -> None: - """Test session log_query.""" - session = Session() - assert session.get_last_queries(1) == [] - - query_key = {"query": "What is your name?", "time": "now"} - response_key = {"response": "I don't have a name", "engine": "nodel"} - session.log_query(query_key, response_key) - assert session.query_id == 1 - assert session.get_last_queries(1) == [(query_key, response_key)] - - query_key2 = {"query2": "What is your name?", "time": "now"} - response_key2 = {"response2": "I don't have a name", "engine": "nodel"} - session.log_query(query_key2, response_key2) - assert session.query_id == 2 - assert len(session.get_last_queries(1)) == 1 - assert session.get_last_queries(2) == [ - (query_key, response_key), - (query_key2, response_key2), - ] - session.close() - - -@pytest.mark.usefixtures("session_cache") -def test_resume_query(session_cache: Path) -> None: - """Test session log_query.""" - session = Session(session_id="dog_days") - query_key = {"query": "What is your name?", "time": "now"} - response_key = {"response": "I don't have a name", "engine": "nodel"} - session.log_query(query_key, response_key) - session.close() - - session = Session(session_id="dog_days") - assert session.query_id == 1 - - -@pytest.mark.usefixtures("session_cache") -def test_session_keys(session_cache: Path) -> None: - """Test get session keys.""" - # Assert empty before queries - assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == [] - # Add queries and make sure session is logged - session = Session(session_id="dog_days") - query_key = {"query": "What is your name?", "time": "now"} - response_key = {"response": "I don't have a name", "engine": "nodel"} - session.log_query(query_key, response_key) - session.close() - - assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == [ - "dog_days" - ]
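
For reference, a minimal sketch of the cache interface the updated `tests/test_cache.py` and `tests/test_manifest.py` hunks exercise: `get` now returns `None` on a miss and writes go through an explicit `set`, replacing the old `get(..., compute=...)` callback style. The cache file name is an illustrative value, not something the patch prescribes.

```python
from manifest.caches.sqlite import SQLiteCache

# Illustrative cache file; the tests pass a fixture path the same way.
cache = SQLiteCache("my_cache.sqlite")

request = {"prompt": "This is a prompt", "engine": "dummy", "num_results": 1}

# On a fresh cache, a miss returns None rather than computing a value in-line.
assert cache.get(request) is None

# Writes are explicit; a later get returns a cached Response.
cache.set(request, {"choices": [{"text": "hello"}]})
response = cache.get(request)
assert response.get_response() == "hello"
assert response.is_cached()
assert response.get_request() == request
```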
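
And a minimal sketch of the async batch path that `test_abatch_run` covers, assuming the same dummy client and SQLite cache the fixtures use; the prompts and cache path are illustrative.

```python
import asyncio

from manifest import Manifest

# Dummy client + SQLite cache, mirroring the test setup.
manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="my_cache.sqlite",
)

# The first batch writes each item to the cache individually.
first = asyncio.run(manifest.arun_batch(["This is a prompt"], return_response=True))
print(first.get_response(manifest.stop_token, is_batch=True))  # ["hello"]

# A later batch that shares an item with an earlier run reports is_cached()
# as True, because cache lookups now happen per item.
second = asyncio.run(
    manifest.arun_batch(["This is a prompt", "New prompt"], return_response=True)
)
print(second.get_response(manifest.stop_token, is_batch=True))  # ["hello", "hello"]
print(second.is_cached())  # True
```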