feat: streaming support completions (#99)

Laurel Orr 12 months ago committed by GitHub
parent b52a4d9a4b
commit 6324e0fe43

@ -4,6 +4,7 @@ Added
^^^^^
* Azure model support (completion and chat)
* Google Vertex API model support (completion and chat)
* Streaming responses for LM Completions (set stream=True)
Fixed
^^^^^

@ -6,9 +6,10 @@ How to make prompt programming with Foundation Models a little easier.
- [Install](#install)
- [Getting Started](#getting-started)
- [Manifest](#manifest-components)
- [Local HuggingFace Models](#local-huggingface-models)
- [Chat Models](#chat-models)
- [Embedding Models](#embedding-models)
- [Other Model Types](#other-models)
- [Local HuggingFace Models](#local-huggingface-models)
- [Chat Models](#chat-models)
- [Embedding Models](#embedding-models)
- [Road Map](#road-map)
- [Development](#development)
- [Cite](#cite)
@ -43,7 +44,7 @@ Running is simple to get started. If using OpenAI, set `export OPENAI_API_KEY=<O
```python
from manifest import Manifest
# Start a manifest session to OpenAI - default `engine=text-davinci-002`
# Start a manifest session to OpenAI - default `engine=text-davinci-003`
manifest = Manifest(
client_name = "openai",
)
@ -142,6 +143,16 @@ If you want to change default parameters to a model, we pass those as `kwargs` t
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
## Streaming Queries
Manifest also supports streaming the model response back, provided the underlying client supports it. When calling `run`, pass `stream=True` to get back an iterator that yields the response as it is generated.
```python
result_iterator = manifest.run("Tell me a story. Once upon a time", max_tokens=100, stream=True)
for res_text in result_iterator:
print(res_text)
```
Streaming responses are currently only supported for single string queries (not batch mode) against text completion models.
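If you also pass `return_response=True`, the iterator yields `Response` objects instead of plain strings, one per streamed chunk, mirroring the streaming tests added in this PR. A minimal sketch:
```python
# Stream Response objects (one per chunk) instead of raw text
result_iterator = manifest.run(
    "Tell me a story. Once upon a time",
    max_tokens=100,
    stream=True,
    return_response=True,
)
for res in result_iterator:
    # Each chunk is a full Response; get_response() extracts its text
    print(res.get_response(), end="")
```
Once the iterator is exhausted, the streamed chunks are stitched together and cached, so re-running the same streaming query is served chunk by chunk from the cache.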
## Model Pools
Manifest supports querying multiple models with different schedulers. This is very much a work in progress, but Manifest will round-robin (or randomly) select among the clients you provide. You can use the same client multiple times with different connection strings (e.g. different API keys), or you can mix and match clients. The only requirement is that all clients in a pool serve the same request type, i.e. you cannot mix generation models and embedding models in one pool.
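A minimal sketch of a two-client pool (the `ClientConnection` setup mirrors the notebook added in this PR; the API keys are placeholders):
```python
from manifest import Manifest
from manifest.connections.client_pool import ClientConnection

# Two connections to the same client type with different API keys
openai_a = ClientConnection(client_name="openai", client_connection="sk-KEY-A")
openai_b = ClientConnection(client_name="openai", client_connection="sk-KEY-B")

# Manifest rotates between the clients in the pool on each request
manifest = Manifest(client_pool=[openai_a, openai_b])
result = manifest.run("Hello, my name is Laurel", max_tokens=10)
```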
@ -169,7 +180,9 @@ The speed benefit comes in with async batched runs. When calling `arun_batch` wi
responses = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))
```
# Local Huggingface Models
# Other Models
## Local Huggingface Models
To use a HuggingFace generative model, `manifest/api` provides a Flask application that hosts the models for you.
In a separate terminal or Tmux/Screen session, to load a 6B parameter model, run
@ -217,7 +230,7 @@ python3 -m manifest.api.app \
--percent_max_gpu_mem_reduction 0.85
```
# Chat Models
## Chat Models
Manifest has specific support for executing against chat models using the more standard "system" / "user" dialogue format. To pass a dialogue history to Manifest, call `run` with a list of dictionaries containing `role` and `content` keys, using an associated chat model such as `openaichat`.
```python
@ -229,7 +242,7 @@ dialogue = [
res = manifest.run(dialogue, max_tokens=100)
```
# Embedding Models
## Embedding Models
Manifest also supports getting embeddings from models and available APIs. This is all done by changing the `client_name` argument; you still use `run` and `arun_batch`.
To use OpenAI's embedding models, simply run
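(a minimal sketch; the `openaiembedding` client name is assumed here, matching the `OpenAIEmbeddingClient` touched later in this diff, and results come back as `numpy` arrays):
```python
from manifest import Manifest

manifest = Manifest(client_name="openaiembedding")
embedding = manifest.run("Is this an embedding?")  # returns a numpy array
```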
@ -250,11 +263,14 @@ python3 -m manifest.api.app \
Here's what's coming up next
- [ ] Clients
- [ ] HuggingFace Hub
- [ ] Azure OpenAI
- [x] Azure OpenAI
- [x] Google Vertex
- [ ] Anthropic
- [x] Streaming Support Completions
- [ ] Streaming Support Chat Models
- [ ] Data Types
- [ ] Diffusion Models
- [ ] Orchestration
- [x] Orchestration
- [x] Connection pools
- [ ] Local Inference
- [ ] FlexGen

@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"OPENAI_KEY = \"sk-XXX\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use ChatOpenAI\n",
"\n",
"Set you `OPENAI_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"openai_chat = ClientConnection(\n",
" client_name=\"openaichat\",\n",
" client_connection=OPENAI_KEY,\n",
" engine=\"gpt-3.5-turbo\"\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[openai_chat])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"manifest_iterator = manifest.run(\"Tell me a story about a fat cat.\\n\\nOnce upon a time\", max_tokens=200, stream=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"cur_line_length = 0\n",
"# Iterate over stream\n",
"for res in manifest_iterator:\n",
" sys.stdout.write(res)\n",
" cur_line_length += len(res)\n",
" if cur_line_length > 80:\n",
" sys.stdout.write(\"\\n\")\n",
" cur_line_length = 0"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -82,6 +82,13 @@ class AI21Client(Client):
"""Return whether the client supports batch inference."""
return False
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -1,10 +1,11 @@
"""Client class."""
import asyncio
import copy
import json
import logging
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, cast
import aiohttp
import requests
@ -107,6 +108,7 @@ class Client(ABC):
"""
Connect to client.
Override in child client class.
Args:
connection_str: connection string.
"""
@ -114,12 +116,18 @@ class Client(ABC):
@abstractmethod
def close(self) -> None:
"""Close the client."""
"""Close the client.
Override in child client class.
"""
raise NotImplementedError()
@abstractmethod
def get_generation_url(self) -> str:
"""Get generation URL."""
"""Get generation URL.
Override in child client class.
"""
raise NotImplementedError()
@abstractmethod
@ -127,6 +135,7 @@ class Client(ABC):
"""
Get generation header.
Override in child client class.
Returns:
header.
"""
@ -134,7 +143,18 @@ class Client(ABC):
@abstractmethod
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
"""Return whether the client supports batch inference.
Override in child client class.
"""
raise NotImplementedError()
@abstractmethod
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
raise NotImplementedError()
@abstractmethod
@ -145,6 +165,7 @@ class Client(ABC):
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Override in child client class.
Returns:
model params.
"""
@ -153,6 +174,8 @@ class Client(ABC):
def get_tokenizer(self, model: str) -> Tuple[Any, int]:
"""Get tokenizer for model.
Override in child client class. Return None, -1 if not supported
or no prompt truncation required.
Returns:
tokenizer: tokenizer with encoder and decode
max_length: max length of model
@ -177,6 +200,8 @@ class Client(ABC):
"""
Preprocess request params.
Override in child client class to reformat requests to model.
Args:
request: request params.
@ -191,6 +216,8 @@ class Client(ABC):
"""
Postprocess and validate response as dict.
Override in child client class to reform model responses.
Args:
response: response
request: request
@ -314,6 +341,7 @@ class Client(ABC):
final_usages = None
if usages:
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
# TODO: Add usage based on tokenizer
return Response(
self._get_model_choices(final_response_dict),
cached=False,
@ -415,6 +443,55 @@ class Client(ABC):
res_json = await res.json(content_type=None)
return self.postprocess_response(res_json, request_params)
@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
)
def _run_streaming_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Generator[Dict, None, None]:
"""Execute completion request streaming.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
request_params = self.preprocess_request_params(request_params)
request_params["stream"] = True
post_str = self.get_generation_url()
res_iter = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
stream=True,
)
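# The server streams server-sent events: each line arrives prefixed with
# "data: " and the stream ends with a "[DONE]" sentinel, both handled below.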
for res_token in res_iter.iter_lines():
if res_token:
decoded_res_token = res_token.decode("utf-8")
decoded_res_token = decoded_res_token.replace("data: ", "")
if decoded_res_token == "[DONE]":
break
try:
decoded_res_token_dct = json.loads(decoded_res_token)
postprocess_res_token_dct = self.postprocess_response(
decoded_res_token_dct, request_params
)
# If nothing is returned, skip
if (
not postprocess_res_token_dct
or not postprocess_res_token_dct["choices"]
):
continue
yield postprocess_res_token_dct
except Exception as e:
raise e
def run_request(self, request: Request) -> Response:
"""
Run request.
@ -563,6 +640,45 @@ class Client(ABC):
**RESPONSE_CONSTRUCTORS[LMChatRequest], # type: ignore
)
def run_streaming_request(
self, request: Request
) -> Generator[Response, None, None]:
"""
Run streaming request.
Args:
request: request.
Returns:
response.
"""
if not isinstance(request.prompt, str):
raise ValueError("Streaming requests must have a single prompt.")
if not self.supports_streaming_inference():
raise ValueError(
f"{self.__class__.__name__} does not support streaming inference."
)
request_params = self._get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
# Make sure requests are in the request length
# If no tokenizer is set or not LM request, this
# will do nothing
if isinstance(request, LMRequest):
self._verify_request_lengths(
request_params, model=request.engine, max_tokens=request.max_tokens
)
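# Wrap each streamed token dict in a full Response so callers get the same
# interface for streaming and non-streaming requests.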
for token_response in self._run_streaming_completion(
request_params, retry_timeout
):
yield self._stitch_responses(request, [token_response])
def run_score_prompt_request(
self,
request: LMScoreRequest,

@ -81,6 +81,13 @@ class CohereClient(Client):
"""Return whether the client supports batch inference."""
return False
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -72,6 +72,13 @@ class DiffuserClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -48,6 +48,13 @@ class DummyClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.

@ -117,6 +117,13 @@ class GoogleClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -66,6 +66,13 @@ class HuggingFaceClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -58,6 +58,13 @@ class HuggingFaceEmbeddingClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -95,6 +95,13 @@ class OpenAIClient(Client):
"""Return whether the client supports batch inference."""
return True
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return True
def get_model_params(self) -> Dict:
"""
Get model params.

@ -129,6 +129,11 @@ class OpenAIChatClient(OpenAIClient):
new_choices = []
response = copy.deepcopy(response)
for message in response["choices"]:
new_choices.append({"text": message["message"]["content"]})
if "delta" in message:
# This is a streaming response
if "content" in message["delta"]:
new_choices.append({"text": message["delta"]["content"]})
else:
new_choices.append({"text": message["message"]["content"]})
response["choices"] = new_choices
return super().postprocess_response(response, request)

@ -76,6 +76,13 @@ class OpenAIEmbeddingClient(OpenAIClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -111,6 +111,13 @@ class TOMAClient(Client):
"""Return whether the client supports batch inference."""
return False
def supports_streaming_inference(self) -> bool:
"""Return whether the client supports streaming inference.
Override in child client class.
"""
return False
def get_model_params(self) -> Dict:
"""
Get model params.

@ -2,7 +2,18 @@
import asyncio
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import (
Any,
Dict,
Generator,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import numpy as np
@ -291,8 +302,17 @@ class Manifest:
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
stream: bool = False,
**kwargs: Any,
) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
) -> Union[
str,
List[str],
np.ndarray,
List[np.ndarray],
Response,
Iterator[str],
Iterator[Response],
]:
"""
Run the prompt.
@ -302,9 +322,11 @@ class Manifest:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
stream: whether to stream the prompt. Only supported
for single string prompts and LMs.
Returns:
response from prompt.
@ -319,6 +341,24 @@ class Manifest:
raise ValueError("Prompt cannot be empty list")
# Get the client to run
client = self.client_pool.get_next_client()
if stream:
if not client.supports_streaming_inference():
raise ValueError(
f"Client {client} does not support streaming inference."
)
if not isinstance(prompt, str):
raise ValueError(
"Stream is only supported for single string prompts. "
"It will soon be supported for chat dictionary prompts, too."
)
return self._run_stream(
prompt=cast(str, prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
if isinstance(prompt, list) and isinstance(prompt[0], dict):
if not client.IS_CHAT:
raise ValueError(
@ -337,15 +377,14 @@ class Manifest:
return_response=return_response,
**kwargs,
)
else:
return self._run(
prompt=cast(Union[str, List[str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
return self._run(
prompt=cast(Union[str, List[str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
def _run(
self,
@ -399,7 +438,6 @@ class Manifest:
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
@ -467,6 +505,77 @@ class Manifest:
else:
return cast(str, final_response.get_response("", is_batch))
def _run_stream(
self,
prompt: str,
client: Client,
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[Generator[str, None, None], Generator[Response, None, None]]:
"""
Run the prompt in a stream.
Args:
prompt: prompt(s) to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
request_params = client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for streaming
if request_params.n > 1:
raise ValueError("Stream mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, client, overwrite_cache
)
if request_params.prompt:
# Because we are streaming, we should have either a cached response
# or a prompt to run, but not both
assert len(cached_idx_to_response) == 0
response_iter = client.run_streaming_request(request_params)
is_cached = False
else:
assert len(cached_idx_to_response) == 1
response_iter = cached_idx_to_response[0].as_iter()
is_cached = True
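# Accumulate the streamed responses so the full generation can be
# stitched together and cached once the stream is exhausted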
saved_responses = []
# Start timing metrics
self.client_pool.start_timer()
for response_token in response_iter:
saved_responses.append(response_token)
if return_response:
yield response_token
else:
yield cast(
Union[str, Response], response_token.get_response("", is_batch)
)
self.client_pool.end_timer()
if not is_cached:
final_response = Response.union_all(
saved_responses, as_single_lmchoice=True
)
self._stitch_responses_and_cache(
request=request_params,
client=client,
response=final_response,
cached_idx_to_response=cached_idx_to_response,
)
async def arun_batch(
self,
prompts: List[str],

@ -1,7 +1,7 @@
"""Client response."""
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Dict, Generator, List, Optional, Type, Union, cast
import numpy as np
from pydantic import BaseModel
@ -154,9 +154,7 @@ class Response:
stop_token: stop token for string generation
is_batch: whether response is batched
"""
process_result = (
lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
)
process_result = lambda x: x.split(stop_token)[0] if stop_token else x
extracted_items = [
choice.text if isinstance(choice, LMModelChoice) else choice.array
for choice in self._response.choices
@ -173,8 +171,17 @@ class Response:
return processed_results
@classmethod
def union_all(cls, responses: List["Response"]) -> "Response":
"""Union a list of response."""
def union_all(
cls, responses: List["Response"], as_single_lmchoice: bool = False
) -> "Response":
"""Union a list of response.
Args:
responses: list of responses to union.
as_single_lmchoice: if True, will concatenate all responses into a single
model choice. Useful for merging streaming responses. Only valid
for LMRequest responses.
"""
if not responses:
raise ValueError("Response list is empty.")
if len(responses) == 1:
@ -184,6 +191,9 @@ class Response:
response_type = first_response._response_type
request = first_response.get_request_obj()
if as_single_lmchoice and response_type != "text":
raise ValueError("as_single_lmchoice=True only works for text responses.")
# Make sure all responses have the same keys
if not all(
[
@ -197,7 +207,7 @@ class Response:
# Get all the prompts and model choices
all_prompts = []
all_choices = []
all_usages = []
all_usages: List[Usage] = []
all_engines = []
for res in responses:
all_engines.extend(res.get_request_obj().engine.split(ENGINE_SEP))
@ -213,18 +223,115 @@ class Response:
all_usages.extend([Usage()] * len(res_prompt))
new_request = copy.deepcopy(request)
new_request.engine = ENGINE_SEP.join(sorted(set(all_engines)))
new_request.prompt = all_prompts
new_response = ModelChoices(choices=all_choices)
new_usages = Usages(usages=all_usages)
response_obj = cls(
response=new_response,
cached=any(res.is_cached() for res in responses),
request=new_request,
usages=new_usages,
request_type=request_type,
response_type=response_type,
)
return response_obj
if as_single_lmchoice:
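# Merge the streamed choices into a single choice: concatenate text,
# logprobs, and tokens in order and sum the usages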
if len(set(all_prompts)) != 1:
raise ValueError("Prompts must be the same for as_single_lmchoice=True")
all_choices_txt = cast(List[LMModelChoice], all_choices) # type: ignore
single_prompt = all_prompts[0]
single_text = "".join([choice.text for choice in all_choices_txt])
single_logprobs = [
logprob
for choice in all_choices_txt
for logprob in choice.token_logprobs or []
]
single_tokens = [
token for choice in all_choices_txt for token in choice.tokens or []
]
single_usage = Usage(
completion_tokens=sum(usg.completion_tokens for usg in all_usages),
prompt_tokens=sum(usg.prompt_tokens for usg in all_usages),
total_tokens=sum(usg.total_tokens for usg in all_usages),
)
new_choices = [
LMModelChoice(
text=single_text,
token_logprobs=single_logprobs,
tokens=single_tokens,
)
]
new_responses = ModelChoices(choices=new_choices) # type: ignore
new_usages = Usages(usages=[single_usage])
new_request.prompt = single_prompt
response_obj = cls(
response=new_responses,
cached=any(res.is_cached() for res in responses),
request=new_request,
usages=new_usages,
request_type=request_type,
response_type=response_type,
)
return response_obj
else:
new_request.prompt = all_prompts
new_response = ModelChoices(choices=all_choices)
new_usages = Usages(usages=all_usages)
response_obj = cls(
response=new_response,
cached=any(res.is_cached() for res in responses),
request=new_request,
usages=new_usages,
request_type=request_type,
response_type=response_type,
)
return response_obj
# Return a token by token iterator over the response
def as_iter(self) -> Generator["Response", None, None]:
"""Return a token by token iterator over the response.
Will return iterator of responses with one token each.
"""
if self._response_type not in {"text"}:
raise ValueError(
f"Invalid response type {self._response_type} for as_iter()"
)
if not self._response.choices:
raise ValueError("No choices in response.")
if len(self._response.choices) > 1:
raise ValueError(
"Response has more than one choice. as_iter() "
"should be over single choice responses."
)
if not isinstance(self._response.choices[0], LMModelChoice):
raise ValueError(
"response_type is text but response is "
f"{self._response.choices[0].__class__}"
)
choice = cast(LMModelChoice, self._response.choices[0])
# If tokens, return iterator of tokens
if choice.tokens:
for token, logprob in zip(choice.tokens, choice.token_logprobs):
yield Response(
response=ModelChoices(
choices=[
LMModelChoice(
text=token, token_logprobs=[logprob], tokens=[token]
)
]
),
cached=self._cached,
request=self._request,
usages=self._usages,
request_type=self._request_type,
response_type=self._response_type,
)
# Otherwise, do it by words
else:
for i, word in enumerate(choice.text.split(" ")):
word = " " + word if i > 0 else word
yield Response(
response=ModelChoices(
choices=[
LMModelChoice(text=word, token_logprobs=None, tokens=None)
]
),
cached=self._cached,
request=self._request,
usages=self._usages,
request_type=self._request_type,
response_type=self._response_type,
)
def serialize(self) -> str:
"""

@ -17,8 +17,10 @@ def model_choice() -> ModelChoices:
"""Get dummy model choice."""
model_choices = ModelChoices(
choices=[
LMModelChoice(text="hello", token_logprobs=[0.1, 0.2]),
LMModelChoice(text="bye", token_logprobs=[0.3]),
LMModelChoice(
text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"]
),
LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"]),
]
)
return model_choices
@ -29,7 +31,9 @@ def model_choice_single() -> ModelChoices:
"""Get dummy model choice."""
model_choices = ModelChoices(
choices=[
LMModelChoice(text="helloo", token_logprobs=[0.1, 0.2]),
LMModelChoice(
text="helloo", token_logprobs=[0.1, 0.2], tokens=["hel", "loo"]
),
]
)
return model_choices

@ -1,7 +1,7 @@
"""Manifest test."""
import asyncio
import os
from typing import cast
from typing import Iterator, cast
from unittest.mock import MagicMock, Mock, patch
import numpy as np
@ -787,6 +787,45 @@ def test_openai(sqlite_cache: str) -> None:
)
assert response.is_cached() is True
# Test streaming
num_responses = 0
streaming_response_text = cast(
Iterator[str], client.run("Why are there oranges?", stream=True)
)
for res_text in streaming_response_text:
num_responses += 1
assert isinstance(res_text, str) and len(res_text) > 0
assert num_responses == 8
streaming_response = cast(
Iterator[Response],
client.run("Why are there mandarines?", return_response=True, stream=True),
)
num_responses = 0
merged_res = []
for res in streaming_response:
num_responses += 1
assert isinstance(res, Response) and len(res.get_response()) > 0
merged_res.append(cast(str, res.get_response()))
assert not res.is_cached()
assert num_responses == 10
# Make sure cached
streaming_response = cast(
Iterator[Response],
client.run("Why are there mandarines?", return_response=True, stream=True),
)
num_responses = 0
merged_res_cached = []
for res in streaming_response:
num_responses += 1
assert isinstance(res, Response) and len(res.get_response()) > 0
merged_res_cached.append(cast(str, res.get_response()))
assert res.is_cached()
# OpenAI stream does not return logprobs, so this is by number of words
assert num_responses == 7
assert "".join(merged_res) == "".join(merged_res_cached)
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
@ -796,6 +835,7 @@ def test_openaichat(sqlite_cache: str) -> None:
client_name="openaichat",
cache_name="sqlite",
cache_connection=sqlite_cache,
temperature=0.0,
)
res = client.run("Why are there apples?")
@ -868,6 +908,45 @@ def test_openaichat(sqlite_cache: str) -> None:
response = cast(Response, client.run(chat_dict, return_response=True))
assert response.is_cached() is False
# Test streaming
num_responses = 0
streaming_response_text = cast(
Iterator[str], client.run("Why are there oranges?", stream=True)
)
for res_text in streaming_response_text:
num_responses += 1
assert isinstance(res_text, str) and len(res_text) > 0
assert num_responses == 9
streaming_response = cast(
Iterator[Response],
client.run("Why are there mandarines?", return_response=True, stream=True),
)
num_responses = 0
merged_res = []
for res in streaming_response:
num_responses += 1
assert isinstance(res, Response) and len(res.get_response()) > 0
merged_res.append(cast(str, res.get_response()))
assert not res.is_cached()
assert num_responses == 10
# Make sure cached
streaming_response = cast(
Iterator[Response],
client.run("Why are there mandarines?", return_response=True, stream=True),
)
num_responses = 0
merged_res_cached = []
for res in streaming_response:
num_responses += 1
assert isinstance(res, Response) and len(res.get_response()) > 0
merged_res_cached.append(cast(str, res.get_response()))
assert res.is_cached()
# OpenAI stream does not return logprobs, so this is by number of words
assert num_responses == 7
assert "".join(merged_res) == "".join(merged_res_cached)
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
@ -1156,7 +1235,7 @@ def test_retry_handling() -> None:
with patch("manifest.clients.client.requests.post", mock_create):
# Run manifest
result = client.run(prompts, temperature=0, overwrite_cache=True)
assert result == ["WHATTT.", "UH OH.", "HARG"]
assert result == [" WHATTT.", " UH OH.", " HARG"]
# Assert that OpenAI client was called twice
assert mock_create.call_count == 2

@ -6,7 +6,13 @@ import pytest
from manifest import Response
from manifest.request import EmbeddingRequest, LMRequest
from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages
from manifest.response import (
ArrayModelChoice,
LMModelChoice,
ModelChoices,
Usage,
Usages,
)
def test_init(
@ -275,9 +281,9 @@ def test_union_all(
final_response = Response.union_all([response1, response2])
assert final_response.get_json_response() == {
"choices": [
{"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None},
{"text": "bye", "token_logprobs": [0.3], "tokens": None},
{"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None},
{"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "lo"]},
{"text": "bye", "token_logprobs": [0.3], "tokens": ["bye"]},
{"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "loo"]},
]
}
assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
@ -299,3 +305,83 @@ def test_union_all(
assert final_response.get_usage_obj() == Usages(
usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()]
)
# Test merge to single
model_choices = ModelChoices(
choices=[
LMModelChoice(
text=" helloo this is a bug",
token_logprobs=[0.1, 0.2, 0.3],
tokens=[" helloo", " this is", " a bug"],
),
]
)
request = LMRequest(prompt="monkey", engine="dummy")
response1 = Response(
response=model_choices,
cached=False,
request=request,
usages=None,
request_type=LMRequest,
response_type="text",
)
final_response = Response.union_all([response1, response1], as_single_lmchoice=True)
assert final_response.get_json_response() == {
"choices": [
{
"text": " helloo this is a bug helloo this is a bug",
"token_logprobs": [0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
"tokens": [
" helloo",
" this is",
" a bug",
" helloo",
" this is",
" a bug",
],
},
]
}
assert final_response.get_usage_obj() == Usages(usages=[Usage()])
assert final_response.get_request_obj().prompt == "monkey"
assert final_response.get_request_obj().engine == "dummy"
def test_as_iter(
model_choice_single: ModelChoices, request_lm_single: LMRequest
) -> None:
"""Test as iter."""
response = Response(
response=model_choice_single,
cached=False,
request=request_lm_single,
usages=None,
request_type=LMRequest,
response_type="text",
)
response_iter_list = list(response.as_iter())
assert len(response_iter_list) == 2
assert response_iter_list[0].get_response() == "hel"
assert response_iter_list[1].get_response() == "loo"
model_choices = ModelChoices(
choices=[
LMModelChoice(text="helloo this is a bug"),
]
)
request = LMRequest(prompt="monkey", engine="dummy")
response = Response(
response=model_choices,
cached=False,
request=request,
usages=None,
request_type=LMRequest,
response_type="text",
)
response_iter_list = list(response.as_iter())
assert len(response_iter_list) == 5
assert response_iter_list[0].get_response() == "helloo"
assert response_iter_list[1].get_response() == " this"
assert response_iter_list[2].get_response() == " is"
assert response_iter_list[3].get_response() == " a"
assert response_iter_list[4].get_response() == " bug"
