feat: unify run_chat and run (#92)

Laurel Orr 1 year ago committed by GitHub
parent 5ad4b017b5
commit 147436c9b2

@@ -1,5 +1,9 @@
0.1.6 - Unreleased
---------------------
Fixed
^^^^^
* Unified `run` and `run_chat` into a single `run` method.
* LLaMA HF models for eval
0.1.5 - 2022-05-03
---------------------

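For reference, a minimal sketch of the unified call this release introduces (assuming the usual `from manifest import Manifest` import and an `openaichat` client, as in the README below):

```python
from manifest import Manifest

manifest = Manifest(client_name="openaichat")

# A plain string still goes through the standard completion path
print(manifest.run("What is the date?", max_tokens=20))

# A list of role/content dicts is now routed to the chat path by the same method
dialogue = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the date?"},
]
print(manifest.run(dialogue, max_tokens=100))
```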
@@ -218,7 +218,7 @@ python3 -m manifest.api.app \
```
# Chat Models
Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass in a dialogue history to Manifest, you must use the `run_chat` command with an associated chat model such as `openaichat`.
Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass a dialogue history to Manifest, use the `run` method with a list of dictionaries that have `role` and `content` keys, against an associated chat model such as `openaichat`.
```python
manifest = Manifest(client_name="openaichat")
@@ -226,7 +226,7 @@ dialogue = [
{"role": "system", "content": "You are a helpful assistant who also responds in rhymes"},
{"role": "user", "content": "What is the date?"},
]
res = manifest.run_chat(dialogue, max_tokens=100)
res = manifest.run(dialogue, max_tokens=100)
```
# Embedding Models

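Following up on the chat example above: as the updated tests below exercise, the same `run` call can hand back the full `Response` object (with usage and cache metadata) instead of the extracted text. A short sketch:

```python
# return_response=True yields a Response object rather than a plain string
response = manifest.run(dialogue, max_tokens=100, return_response=True)
print(response.is_cached())
print(response.get_usage_obj().usages)
```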
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -49,17 +49,9 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The 2020 World Series was played at Globe\n"
]
}
],
"outputs": [],
"source": [
"# Simple question\n",
"chat_dict = [\n",
@@ -68,7 +60,7 @@
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
"]\n",
"print(manifest.run_chat(chat_dict))"
"print(manifest.run(chat_dict, max_tokens=100))"
]
}
],

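The notebook hunks above just clear stored outputs and execution counts. A hedged sketch of doing the same mechanically with `nbformat` (the filename is hypothetical; the notebook may simply have been re-saved by hand):

```python
import nbformat

# Load, scrub, and rewrite a notebook so its diffs stay reviewable
nb = nbformat.read("chatgpt.ipynb", as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.execution_count = None
        cell.outputs = []
nbformat.write(nb, "chatgpt.ipynb")
```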
@@ -51,6 +51,7 @@ class Client(ABC):
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = Request
NAME: str = None
IS_CHAT: bool = False
def __init__(
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}

@@ -29,6 +29,7 @@ class OpenAIChatClient(OpenAIClient):
}
REQUEST_CLS = LMRequest
NAME = "openaichat"
IS_CHAT = True
def connect(
self,

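Together, these two hunks let `run` tell chat-capable clients apart: the base `Client` defaults `IS_CHAT` to `False`, and `OpenAIChatClient` flips it on. A minimal sketch of how a hypothetical custom client would opt in (the class and `NAME` are illustrative, not from this repo):

```python
class MyChatClient(OpenAIChatClient):
    # run() routes list-of-dict prompts only to clients with IS_CHAT = True
    NAME = "mychat"
    IS_CHAT = True
```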
@@ -287,7 +287,7 @@ class Manifest:
def run(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str], List[Dict[str, str]]],
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
@@ -296,6 +296,8 @@
"""
Run the prompt.
Orchestrates between the standard run, the chat run, and the batch run.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
@@ -307,9 +309,68 @@
Returns:
response from prompt.
"""
is_batch = isinstance(prompt, list)
if not isinstance(prompt, list) and not isinstance(prompt, str):
raise ValueError(
f"Invalid prompt type: {type(prompt)}. "
"Prompt must be a string or list of strings "
"or list of dicts."
)
if isinstance(prompt, list) and not prompt:
raise ValueError("Prompt cannot be empty list")
# Get the client to run
client = self.client_pool.get_next_client()
if isinstance(prompt, list) and isinstance(prompt[0], dict):
if not client.IS_CHAT:
raise ValueError(
f"Client {client} does not support dict chat prompt. "
"Please use a chat model."
)
if stop_token:
logger.warning(
"stop_token is not supported for chat prompt. "
"Ignoring stop_token."
)
return self._run_chat(
prompt=cast(List[Dict[str, str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
return_response=return_response,
)
else:
return self._run(
prompt=cast(Union[str, List[str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
def _run(
self,
prompt: Union[str, List[str]],
client: Client,
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
"""
Run the prompt.
Args:
prompt: prompt(s) to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = isinstance(prompt, list)
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as a dict so the client's "pop" methods can remove used arguments
request_params = client.get_request(prompt, kwargs)
@@ -344,6 +405,67 @@ class Manifest:
else:
return final_response.get_response(stop_token, is_batch)
def _run_chat(
self,
prompt: List[Dict[str, str]],
client: Client,
overwrite_cache: bool = False,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, Response]:
"""
Run the chat prompt.
Args:
prompt: chat dialogue (list of role/content dicts) to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
# Get a request for an empty prompt to handle all kwargs
request_params = client.get_request("", kwargs)
# Add prompt and cast as chat request
request_params_dict = request_params.to_dict()
request_params_dict["prompt"] = prompt
request_params_as_chat = LMChatRequest(**request_params_dict)
# Avoid nested list of results - enforce n = 1 for chat
if request_params_as_chat.n > 1:
raise ValueError("Chat mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params_as_chat)
cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_chat, client, overwrite_cache
)
# If there is anything left to run (not None or an empty list), issue a new request
if request_params_as_chat.prompt:
# Start timing metrics
self.client_pool.start_timer()
response = client.run_chat_request(request_params_as_chat)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params_as_chat,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return cast(str, final_response.get_response("", is_batch))
async def arun_batch(
self,
prompts: List[str],
@@ -381,6 +503,13 @@ class Manifest:
Returns:
response from prompt.
"""
if not isinstance(prompts, list):
raise ValueError("Prompts must be a list of strings.")
if not prompts:
raise ValueError("Prompts must not be empty.")
if not isinstance(prompts[0], str):
raise ValueError("Prompts must be a list of strings.")
# Split the prompts into chunks
prompt_chunks: List[Tuple[Client, List[str]]] = []
if chunk_size > 0:
@@ -464,67 +593,6 @@
)
return final_response
def run_chat(
self,
prompt: List[Dict[str, str]],
overwrite_cache: bool = False,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, Response]:
"""
Run the prompt.
Args:
prompt: prompt dictionary to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
# Get the client to run
client = self.client_pool.get_next_client()
# Get a request for an empty prompt to handle all kwargs
request_params = client.get_request("", kwargs)
# Add prompt and cast as chat request
request_params_dict = request_params.to_dict()
request_params_dict["prompt"] = prompt
request_params_as_chat = LMChatRequest(**request_params_dict)
# Avoid nested list of results - enforce n = 1 for batch
if request_params_as_chat.n > 1:
raise ValueError("Chat mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params_as_chat)
cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_chat, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params_as_chat.prompt:
# Start timing metrics
self.client_pool.start_timer()
response = client.run_chat_request(request_params_as_chat)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params_as_chat,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return cast(str, final_response.get_response("", is_batch))
def score_prompt(
self,
prompt: Union[str, List[str]],

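The net effect of the refactor for callers, sketched below: string prompts keep the old completion semantics, dict prompts get chat semantics, and the two guardrails (the `stop_token` warning and the `n > 1` ValueError) come straight from the dispatch code above:

```python
manifest = Manifest(client_name="openaichat")

# Strings (or lists of strings) keep the old completion behavior
text = manifest.run("Hello", max_tokens=20)

# Dict prompts dispatch to _run_chat; any stop_token is ignored with a warning
reply = manifest.run(
    [{"role": "user", "content": "Hello"}],
    max_tokens=100,
)

# Asking for n > 1 with a dict prompt raises
# ValueError("Chat mode does not support n > 1.")
```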
@@ -11,6 +11,7 @@ module = [
"diffusers",
"sentence_transformers",
"sqlitedict",
"sqlalchemy",
"dill",
"accelerate",
"accelerate.utils.modeling",

@@ -399,11 +399,13 @@ def test_run_chat(sqlite_cache: str) -> None:
cache_name="sqlite",
cache_connection=sqlite_cache,
)
# Set IS_CHAT to True for this model
manifest.client_pool.client_pool[0].IS_CHAT = True
prompt = [
{"role": "system", "content": "Hello."},
]
result = manifest.run_chat(prompt, return_response=False)
result = manifest.run(prompt, return_response=False)
assert result == "Hello."
assert (
manifest.cache.get(
@@ -421,7 +423,7 @@ def test_run_chat(sqlite_cache: str) -> None:
{"role": "system", "content": "Hello."},
{"role": "user", "content": "Goodbye?"},
]
result = manifest.run_chat(prompt, return_response=True)
result = manifest.run(prompt, return_response=True)
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
@@ -849,9 +851,9 @@ def test_openaichat(sqlite_cache: str) -> None:
},
{"role": "user", "content": "Where was it played?"},
]
res = client.run_chat(chat_dict)
res = client.run(chat_dict)
assert isinstance(res, str) and len(res) > 0
response = cast(Response, client.run_chat(chat_dict, return_response=True))
response = cast(Response, client.run(chat_dict, return_response=True))
assert response.is_cached() is True
assert response.get_usage_obj().usages[0].total_tokens == 67
chat_dict = [
@@ -863,7 +865,7 @@ def test_openaichat(sqlite_cache: str) -> None:
},
{"role": "user", "content": "Where was it played?"},
]
response = cast(Response, client.run_chat(chat_dict, return_response=True))
response = cast(Response, client.run(chat_dict, return_response=True))
assert response.is_cached() is False

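One path the updated tests leave uncovered is the new fail-fast for dict prompts against a non-chat client; a hedged pytest sketch (the fixture mirrors the surrounding tests, and the client name is an assumption):

```python
import pytest

def test_run_requires_chat_client(sqlite_cache: str) -> None:
    manifest = Manifest(
        client_name="dummy",  # assumption: any client whose IS_CHAT is False
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    with pytest.raises(ValueError, match="does not support dict chat prompt"):
        manifest.run([{"role": "user", "content": "Hello."}])
```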