diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 60ef57b..aacd9da 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,5 +1,9 @@
 0.1.6 - Unreleased
 ---------------------
+Fixed
+^^^^^
+* Unified `run` and `run_chat` methods so it's just `run` now.
+* Llama HF models for eval
 
 0.1.5 - 2022-05-03
 ---------------------
diff --git a/README.md b/README.md
index e550b3c..6fe934c 100644
--- a/README.md
+++ b/README.md
@@ -218,7 +218,7 @@ python3 -m manifest.api.app \
 ```
 
 # Chat Models
-Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass in a dialogue history to Manifest, you must use the `run_chat` command with an associated chat model such as `openaichat`.
+Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass in a dialogue history to Manifest, use the `run` command with a list of dictionaries containing `role` and `content` keys and an associated chat model such as `openaichat`.
 
 ```python
 manifest = Manifest(client_name="openaichat")
@@ -226,7 +226,7 @@ dialogue = [
     {"role": "system", "content": "You are a helpful assistant who also responds in rhymes"},
     {"role": "user", "content": "What is the date?"},
 ]
-res = manifest.run_chat(dialogue, max_tokens=100)
+res = manifest.run(dialogue, max_tokens=100)
 ```
 
 # Embedding Models
diff --git a/examples/manifest_chatgpt.ipynb b/examples/manifest_chatgpt.ipynb
index 28f30fd..4dc2339 100644
--- a/examples/manifest_chatgpt.ipynb
+++ b/examples/manifest_chatgpt.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -31,7 +31,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -49,17 +49,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The 2020 World Series was played at Globe\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Simple question\n",
     "chat_dict = [\n",
@@ -68,7 +60,7 @@
     "    {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
     "    {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
     "]\n",
-    "print(manifest.run_chat(chat_dict))"
+    "print(manifest.run(chat_dict, max_tokens=100))"
    ]
   }
 ],
diff --git a/manifest/clients/client.py b/manifest/clients/client.py
index 4b37863..10692b0 100644
--- a/manifest/clients/client.py
+++ b/manifest/clients/client.py
@@ -51,6 +51,7 @@ class Client(ABC):
     PARAMS: Dict[str, Tuple[str, Any]] = {}
     REQUEST_CLS = Request
     NAME: str = None
+    IS_CHAT: bool = False
 
     def __init__(
         self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
diff --git a/manifest/clients/openai_chat.py b/manifest/clients/openai_chat.py
index ab83700..f06a18c 100644
--- a/manifest/clients/openai_chat.py
+++ b/manifest/clients/openai_chat.py
@@ -29,6 +29,7 @@ class OpenAIChatClient(OpenAIClient):
     }
     REQUEST_CLS = LMRequest
     NAME = "openaichat"
+    IS_CHAT = True
 
     def connect(
         self,
diff --git a/manifest/manifest.py b/manifest/manifest.py
index 9221925..0721b29 100644
--- a/manifest/manifest.py
+++ b/manifest/manifest.py
@@ -287,7 +287,7 @@ class Manifest:
 
     def run(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[Dict[str, str]]],
         overwrite_cache: bool = False,
         stop_token: Optional[str] = None,
         return_response: bool = False,
@@ -296,6 +296,8 @@ class Manifest:
         """
         Run the prompt.
 
+        Orchestrates between the standard run, the chat run, and the batch run.
+
         Args:
             prompt: prompt(s) to run.
             overwrite_cache: whether to overwrite cache.
@@ -307,9 +309,68 @@ class Manifest:
         Returns:
             response from prompt.
         """
-        is_batch = isinstance(prompt, list)
+        if not isinstance(prompt, list) and not isinstance(prompt, str):
+            raise ValueError(
+                f"Invalid prompt type: {type(prompt)}. "
+                "Prompt must be a string or list of strings "
+                "or list of dicts."
+            )
+        if isinstance(prompt, list) and not prompt:
+            raise ValueError("Prompt cannot be empty list")
         # Get the client to run
         client = self.client_pool.get_next_client()
+        if isinstance(prompt, list) and isinstance(prompt[0], dict):
+            if not client.IS_CHAT:
+                raise ValueError(
+                    f"Client {client} does not support dict chat prompt. "
+                    "Please use a chat model."
+                )
+            if stop_token:
+                logger.warning(
+                    "stop_token is not supported for chat prompt. "
+                    "Ignoring stop_token."
+                )
+            return self._run_chat(
+                prompt=cast(List[Dict[str, str]], prompt),
+                client=client,
+                overwrite_cache=overwrite_cache,
+                return_response=return_response,
+            )
+        else:
+            return self._run(
+                prompt=cast(Union[str, List[str]], prompt),
+                client=client,
+                overwrite_cache=overwrite_cache,
+                stop_token=stop_token,
+                return_response=return_response,
+                **kwargs,
+            )
+
+    def _run(
+        self,
+        prompt: Union[str, List[str]],
+        client: Client,
+        overwrite_cache: bool = False,
+        stop_token: Optional[str] = None,
+        return_response: bool = False,
+        **kwargs: Any,
+    ) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
+        """
+        Run the prompt.
+
+        Args:
+            prompt: prompt(s) to run.
+            client: client to run.
+            overwrite_cache: whether to overwrite cache.
+            stop_token: stop token for prompt generation.
+                Default is self.stop_token.
+                "" for no stop token.
+            return_response: whether to return Response object.
+
+        Returns:
+            response from prompt.
+        """
+        is_batch = isinstance(prompt, list)
         stop_token = stop_token if stop_token is not None else self.stop_token
         # Must pass kwargs as dict for client "pop" methods removed used arguments
         request_params = client.get_request(prompt, kwargs)
@@ -344,6 +405,67 @@ class Manifest:
         else:
             return final_response.get_response(stop_token, is_batch)
 
+    def _run_chat(
+        self,
+        prompt: List[Dict[str, str]],
+        client: Client,
+        overwrite_cache: bool = False,
+        return_response: bool = False,
+        **kwargs: Any,
+    ) -> Union[str, Response]:
+        """
+        Run the prompt.
+
+        Args:
+            prompt: prompt dictionary to run.
+            client: client to run.
+            overwrite_cache: whether to overwrite cache.
+            stop_token: stop token for prompt generation.
+                Default is self.stop_token.
+                "" for no stop token.
+            return_response: whether to return Response object.
+
+        Returns:
+            response from prompt.
+ """ + is_batch = False + # Get a request for an empty prompt to handle all kwargs + request_params = client.get_request("", kwargs) + # Add prompt and cast as chat request + request_params_dict = request_params.to_dict() + request_params_dict["prompt"] = prompt + request_params_as_chat = LMChatRequest(**request_params_dict) + # Avoid nested list of results - enforce n = 1 for batch + if request_params_as_chat.n > 1: + raise ValueError("Chat mode does not support n > 1.") + self._validate_kwargs(kwargs, request_params_as_chat) + + cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501 + request_params_as_chat, client, overwrite_cache + ) + # If not None value or empty list - run new request + if request_params_as_chat.prompt: + # Start timing metrics + self.client_pool.start_timer() + response = client.run_chat_request(request_params_as_chat) + self.client_pool.end_timer() + else: + # Nothing to run + response = None + + final_response = self._stitch_responses_and_cache( + request=request_params_as_chat, + client=client, + response=response, + cached_idx_to_response=cached_idx_to_response, + ) + + # Extract text results + if return_response: + return final_response + else: + return cast(str, final_response.get_response("", is_batch)) + async def arun_batch( self, prompts: List[str], @@ -381,6 +503,13 @@ class Manifest: Returns: response from prompt. """ + if not isinstance(prompts, list): + raise ValueError("Prompts must be a list of strings.") + if not prompts: + raise ValueError("Prompts must not be empty.") + if not isinstance(prompts[0], str): + raise ValueError("Prompts must be a list of strings.") + # Split the prompts into chunks prompt_chunks: List[Tuple[Client, List[str]]] = [] if chunk_size > 0: @@ -464,67 +593,6 @@ class Manifest: ) return final_response - def run_chat( - self, - prompt: List[Dict[str, str]], - overwrite_cache: bool = False, - return_response: bool = False, - **kwargs: Any, - ) -> Union[str, Response]: - """ - Run the prompt. - - Args: - prompt: prompt dictionary to run. - overwrite_cache: whether to overwrite cache. - stop_token: stop token for prompt generation. - Default is self.stop_token. - "" for no stop token. - return_response: whether to return Response object. - - Returns: - response from prompt. 
- """ - is_batch = False - # Get the client to run - client = self.client_pool.get_next_client() - # Get a request for an empty prompt to handle all kwargs - request_params = client.get_request("", kwargs) - # Add prompt and cast as chat request - request_params_dict = request_params.to_dict() - request_params_dict["prompt"] = prompt - request_params_as_chat = LMChatRequest(**request_params_dict) - # Avoid nested list of results - enforce n = 1 for batch - if request_params_as_chat.n > 1: - raise ValueError("Chat mode does not support n > 1.") - self._validate_kwargs(kwargs, request_params_as_chat) - - cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501 - request_params_as_chat, client, overwrite_cache - ) - # If not None value or empty list - run new request - if request_params_as_chat.prompt: - # Start timing metrics - self.client_pool.start_timer() - response = client.run_chat_request(request_params_as_chat) - self.client_pool.end_timer() - else: - # Nothing to run - response = None - - final_response = self._stitch_responses_and_cache( - request=request_params_as_chat, - client=client, - response=response, - cached_idx_to_response=cached_idx_to_response, - ) - - # Extract text results - if return_response: - return final_response - else: - return cast(str, final_response.get_response("", is_batch)) - def score_prompt( self, prompt: Union[str, List[str]], diff --git a/pyproject.toml b/pyproject.toml index 0e63470..5cd0250 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ module = [ "diffusers", "sentence_transformers", "sqlitedict", + "sqlalchemy", "dill", "accelerate", "accelerate.utils.modeling", diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 5cfec9a..a6a3555 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -399,11 +399,13 @@ def test_run_chat(sqlite_cache: str) -> None: cache_name="sqlite", cache_connection=sqlite_cache, ) + # Set CHAT to be true for this model + manifest.client_pool.client_pool[0].IS_CHAT = True prompt = [ {"role": "system", "content": "Hello."}, ] - result = manifest.run_chat(prompt, return_response=False) + result = manifest.run(prompt, return_response=False) assert result == "Hello." assert ( manifest.cache.get( @@ -421,7 +423,7 @@ def test_run_chat(sqlite_cache: str) -> None: {"role": "system", "content": "Hello."}, {"role": "user", "content": "Goodbye?"}, ] - result = manifest.run_chat(prompt, return_response=True) + result = manifest.run(prompt, return_response=True) assert isinstance(result, Response) result = cast(Response, result) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) @@ -849,9 +851,9 @@ def test_openaichat(sqlite_cache: str) -> None: }, {"role": "user", "content": "Where was it played?"}, ] - res = client.run_chat(chat_dict) + res = client.run(chat_dict) assert isinstance(res, str) and len(res) > 0 - response = cast(Response, client.run_chat(chat_dict, return_response=True)) + response = cast(Response, client.run(chat_dict, return_response=True)) assert response.is_cached() is True assert response.get_usage_obj().usages[0].total_tokens == 67 chat_dict = [ @@ -863,7 +865,7 @@ def test_openaichat(sqlite_cache: str) -> None: }, {"role": "user", "content": "Where was it played?"}, ] - response = cast(Response, client.run_chat(chat_dict, return_response=True)) + response = cast(Response, client.run(chat_dict, return_response=True)) assert response.is_cached() is False