fix: added openai usage back (#69)

pull/82/head
Laurel Orr 1 year ago committed by GitHub
parent 395ac06a95
commit e4d3a57f92

@@ -94,12 +94,13 @@ class AI21Client(Client):
"""
return {"model_name": "ai21", "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@@ -142,18 +142,34 @@ class Client(ABC):
request_params.update(self.get_model_params())
return request_params
def format_response(self, response: Dict) -> Dict[str, Any]:
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
return []
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "choices" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["choices"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["choices"])
if split_usage:
response["usage"] = split_usage
return response
def split_requests(
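For illustration, a minimal sketch of the data shapes the reworked format_response is meant to produce (hypothetical responses and token counts, not part of this diff): a single-choice response ends up with a one-element usage list, while multi-choice responses are only split when a subclass implements split_usage.

single = {
    "choices": [{"text": "hi"}],
    "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4},
}
# After format_response: usage is wrapped so there is one usage dict per choice.
# single["usage"] == [{"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}]

batch = {
    "choices": [{"text": "a"}, {"text": "b"}],
    "usage": {"prompt_tokens": 6, "completion_tokens": 2, "total_tokens": 8},
}
# With multiple choices, format_response calls split_usage(request, choices); the base
# Client returns [] (leaving the aggregate usage untouched), while OpenAIClient (below)
# replaces it with a per-choice list computed via tiktoken.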
@@ -208,7 +224,7 @@ class Client(ABC):
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json())
return self.format_response(res.json(), request_params)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
@@ -234,7 +250,7 @@
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json)
return self.format_response(res_json, request_params)
except aiohttp.ClientError as e:
logger.error(f"{self.__class__.__name__} request error {e}")
raise e
@@ -307,9 +323,14 @@ class Client(ABC):
responses = await asyncio.gather(*all_tasks)
# Flatten responses
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
final_response_dict = self.format_response({"choices": choices})
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
if usages:
final_response_dict["usage"] = usages
return Response(
final_response_dict,
cached=False,

@@ -93,12 +93,13 @@ class CohereClient(Client):
"""
return {"model_name": "cohere", "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@@ -85,12 +85,13 @@ class DiffuserClient(Client):
res = requests.post(self.host + "/params")
return res.json()
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@@ -89,7 +89,10 @@ class DummyClient(Client):
response_dict = {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results
* num_results,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}]
* int(request_params["num_results"])
* num_results,
}
return Response(response_dict, False, request_params)
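As a concrete example of the new dummy payload (shape only, mirroring the values in the diff above), a request that yields two results now returns parallel choices and usage lists:

response_dict = {
    "choices": [{"text": "hello"}, {"text": "hello"}],
    "usage": [
        {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
        {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    ],
}
# The tests further down rely on this invariant: one usage dict per choice.
assert len(response_dict["usage"]) == len(response_dict["choices"])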

@@ -1,7 +1,9 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import tiktoken
from manifest.clients.client import Client
from manifest.request import LMRequest
@@ -102,3 +104,28 @@ class OpenAIClient(Client):
model params.
"""
return {"model_name": "openai", "engine": getattr(self, "engine")}
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
try:
encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
except Exception:
return []
prompt = request["prompt"]
# If n > 1 and prompt is a string, we need to split it into a list
if isinstance(prompt, str):
prompts = [prompt] * len(choices)
else:
prompts = prompt
assert len(prompts) == len(choices)
usages = []
for pmt, chc in zip(prompts, choices):
pmt_tokens = len(encoding.encode(pmt))
chc_tokens = len(encoding.encode(chc["text"])) # type: ignore
usage = {
"prompt_tokens": pmt_tokens,
"completion_tokens": chc_tokens,
"total_tokens": pmt_tokens + chc_tokens,
}
usages.append(usage)
return usages
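A rough usage sketch of split_usage in isolation (engine name, prompts, and completions are illustrative; tiktoken is the new dependency added to setup.py below):

import tiktoken

# Mirrors what OpenAIClient.split_usage does for a two-prompt batch.
encoding = tiktoken.encoding_for_model("text-ada-001")
prompts = ["Why are there apples?", "Why are there mangos?"]
choices = [{"text": " Because they grow on trees."}, {"text": " Because they are tropical fruit."}]
usages = []
for pmt, chc in zip(prompts, choices):
    pmt_tokens = len(encoding.encode(pmt))
    chc_tokens = len(encoding.encode(chc["text"]))
    usages.append(
        {
            "prompt_tokens": pmt_tokens,
            "completion_tokens": chc_tokens,
            "total_tokens": pmt_tokens + chc_tokens,
        }
    )
# usages has one entry per choice, in prompt order; the real method instead returns []
# when tiktoken does not recognize the engine.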

@@ -4,7 +4,7 @@ import logging
import os
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.clients.openai import OpenAIClient
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ OPENAICHAT_ENGINES = {
}
class OpenAIChatClient(Client):
class OpenAIChatClient(OpenAIClient):
"""OpenAI Chat client."""
# User param -> (client param, default value)
@@ -60,23 +60,10 @@ class OpenAIChatClient(Client):
f"Must be {OPENAICHAT_ENGINES}."
)
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/chat/completions"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"Authorization": f"Bearer {self.api_key}"}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False

@@ -143,12 +143,13 @@ class TOMAClient(Client):
}
return heartbeats
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@@ -46,12 +46,13 @@ class TOMADiffuserClient(TOMAClient):
"""
return {"model_name": "tomadiffuser", "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@@ -222,6 +222,7 @@ class Manifest:
# We stitch the responses (the choices) here from both the new request and the
# cached entries.
all_model_choices = []
all_usages = []
all_input_prompts = []
response_idx = 0
number_prompts = len(cached_idx_to_response)
@@ -241,24 +242,23 @@ class Manifest:
response_gen_key = cached_res.generation_key
response_logits_key = cached_res.logits_key
response_item_key = cached_res.item_key
response_usage_key = cached_res.usage_key
all_input_prompts.append(cached_res.get_request()["prompt"])
json_response = cached_res.get_json_response()
if request.n == 1:
assert (
len(cached_res.get_json_response()[response_gen_key]) == 1
len(json_response[response_gen_key]) == 1
), "cached response should have only one choice"
all_model_choices.append(
cached_res.get_json_response()[response_gen_key][0]
)
else:
all_model_choices.extend(
cached_res.get_json_response()[response_gen_key]
)
all_model_choices.extend(json_response[response_gen_key])
if response_usage_key:
all_usages.extend(json_response[response_usage_key])
else:
assert response is not None, "response should not be None"
response = cast(Response, response)
response_gen_key = response.generation_key
response_logits_key = response.logits_key
response_item_key = response.item_key
response_usage_key = response.usage_key
# the choices list in the response is a flat one.
# length is request.n * num_prompts
current_choices = response.get_json_response()[response_gen_key][
@@ -270,6 +270,11 @@ class Manifest:
prompt = response.get_request()["prompt"][response_idx]
else:
prompt = str(response.get_request()["prompt"])
if response_usage_key:
usage = response.get_json_response()[response_usage_key][
response_idx * request.n : (response_idx + 1) * request.n
]
all_usages.extend(usage)
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
@@ -277,6 +282,8 @@ class Manifest:
cache_key = self.client.get_cache_key(new_request)
new_response_key = copy.deepcopy(response.get_json_response())
new_response_key[response_gen_key] = current_choices
if response_usage_key:
new_response_key[response_usage_key] = usage
self.cache.set(cache_key, new_response_key)
response_idx += 1
@@ -286,13 +293,17 @@ class Manifest:
if len(all_input_prompts) > 1 or not single_output
else all_input_prompts[0]
)
new_response = {response_gen_key: all_model_choices}
if response_usage_key:
new_response[response_usage_key] = all_usages
response_obj = Response(
{response_gen_key: all_model_choices},
new_response,
cached=len(cached_idx_to_response) > 0,
request_params=self.client.get_cache_key(new_request),
generation_key=response_gen_key,
logits_key=response_logits_key,
item_key=response_item_key,
usage_key=response_usage_key,
)
return response_obj
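To make the stitching concrete, a small sketch (hypothetical token counts, not from the diff) of how cached and fresh usage entries are merged in prompt order alongside all_model_choices:

# Illustrative only: two prompts, the first served from cache, the second fresh.
cached_usage = [{"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}]
fresh_usage = [{"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6}]

all_usages = []
all_usages.extend(cached_usage)  # pulled from the cached response's usage key
all_usages.extend(fresh_usage)   # sliced from the new response by response_idx * request.n
# all_usages lines up one-to-one with all_model_choices before the final Response is built.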

@@ -31,12 +31,13 @@ class Response:
def __init__(
self,
response: Dict,
response: Dict, # TODO: make pydantic model
cached: bool,
request_params: Dict,
request_params: Dict, # TODO: use request pydantic model
generation_key: str = "choices",
logits_key: str = "token_logprobs",
item_key: str = "text",
usage_key: str = "usage",
):
"""
Initialize response.
@@ -52,6 +53,7 @@ class Response:
self.generation_key = generation_key
self.logits_key = logits_key
self.item_key = item_key
self.usage_key = usage_key
self.item_dtype = None
if isinstance(response, dict):
self._response = response
@@ -66,6 +68,16 @@ class Response:
"Response must be serialized to a dict with a nonempty"
f" list of choices. Response is\n{self._response}."
)
# Turn off usage if it is not in response
if self.usage_key not in self._response:
self.usage_key = None
else:
if not isinstance(self._response[self.usage_key], list):
raise ValueError(
"Response must be a list with usage dicts, one per choice."
f" Response is\n{self._response}."
)
if self.item_key not in self._response[self.generation_key][0]:
raise ValueError(
"Response must be serialized to a dict with a "

@@ -35,6 +35,7 @@ REQUIRED = [
"aiohttp>=3.8.0",
"sqlitedict>=2.0.0",
"xxhash>=3.0.0",
"tiktoken>=0.3.0",
]
# What packages are optional?

@@ -34,7 +34,10 @@ def test_get_request() -> None:
"num_results": 3,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 3}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 3,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
}
request_params = client.get_request("hello", {"n": 5})
response = client.run_request(request_params)
@@ -43,7 +46,10 @@ def test_get_request() -> None:
"num_results": 5,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}
request_params = client.get_request(["hello"] * 5, {"n": 1})
response = client.run_request(request_params)
@@ -52,4 +58,7 @@ def test_get_request() -> None:
"num_results": 1,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}

@@ -1,5 +1,6 @@
"""Manifest test."""
import asyncio
import os
from typing import cast
import pytest
@@ -17,6 +18,8 @@ try:
except Exception:
MODEL_ALIVE = False
OPENAI_ALIVE = os.environ.get("OPENAI_API_KEY") is not None
@pytest.mark.usefixtures("sqlite_cache")
def test_init(sqlite_cache: str) -> None:
@@ -104,7 +107,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(manifest.stop_token)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
@@ -126,7 +133,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, run_id="34", return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(manifest.stop_token)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
@@ -149,7 +160,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(manifest.stop_token)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
@@ -171,7 +186,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(stop_token="ll")
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(stop_token="ll")
else:
res = cast(str, result)
assert (
@@ -209,9 +228,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
else:
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
else:
res = cast(str, result)
assert res == ["hello"]
@@ -229,9 +251,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
else:
res = cast(str, result)
assert res == ["hello", "hello"]
@@ -263,11 +288,14 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
prompt = ["This is a prompt", "New prompt"]
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
assert cast(Response, result).is_cached()
assert result.is_cached()
else:
res = cast(str, result)
assert res == ["hello", "hello"]
@@ -275,7 +303,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
if return_response:
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(stop_token="ll", is_batch=True)
else:
res = cast(str, result)
assert res == ["he", "he"]
@@ -290,9 +323,14 @@ def test_abatch_run(sqlite_cache: str) -> None:
cache_connection=sqlite_cache,
)
prompt = ["This is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
result = cast(
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello"]
assert (
manifest.cache.get(
@@ -306,8 +344,14 @@ def test_abatch_run(sqlite_cache: str) -> None:
)
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
result = cast(
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello", "hello"]
assert (
manifest.cache.get(
@@ -320,9 +364,15 @@ def test_abatch_run(sqlite_cache: str) -> None:
is not None
)
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert cast(Response, result).is_cached()
result = cast(
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
assert result.is_cached()
assert (
manifest.cache.get(
@@ -335,15 +385,27 @@ def test_abatch_run(sqlite_cache: str) -> None:
is None
)
prompt = ["This is a prompt", "New prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
result = cast(
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
assert cast(Response, result).is_cached()
assert result.is_cached()
assert res == ["hello", "hello"]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
result = cast(
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
)
assert len(result.get_json_response()["usage"]) == len(
result.get_json_response()["choices"]
)
res = result.get_response(stop_token="ll", is_batch=True)
assert res == ["he", "he"]
@@ -484,3 +546,138 @@ def test_local_huggingface(sqlite_cache: str) -> None:
assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
scores["response"]["choices"][0]["tokens"]
)
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openai(sqlite_cache: str) -> None:
"""Test openai client."""
client = Manifest(
client_name="openai",
engine="text-ada-001",
cache_name="sqlite",
cache_connection=sqlite_cache,
temperature=0.0,
)
res = client.run("Why are there apples?")
assert isinstance(res, str) and len(res) > 0
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 15
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
res_list = client.run(["Why are there apples?", "Why are there bananas?"])
assert isinstance(res_list, list) and len(res_list) == 2
response = cast(
Response,
client.run(
["Why are there apples?", "Why are there mangos?"], return_response=True
),
)
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 15
assert response.get_json_response()["usage"][1]["total_tokens"] == 16
response = cast(
Response, client.run("Why are there bananas?", return_response=True)
)
assert response.is_cached() is True
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list, list) and len(res_list) == 2
response = cast(
Response,
asyncio.run(
client.arun_batch(
["Why are there pinenuts?", "Why are there cocoa?"],
return_response=True,
)
),
)
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 17
assert response.get_json_response()["usage"][1]["total_tokens"] == 15
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openaichat(sqlite_cache: str) -> None:
"""Test openaichat client."""
client = Manifest(
client_name="openaichat",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
res = client.run("Why are there apples?")
assert isinstance(res, str) and len(res) > 0
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 22
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is False
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list, list) and len(res_list) == 2
response = cast(
Response,
asyncio.run(
client.arun_batch(
["Why are there pinenuts?", "Why are there cocoa?"],
return_response=True,
)
),
)
assert (
isinstance(response.get_response(), list) and len(response.get_response()) == 2
)
assert (
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 24
assert response.get_json_response()["usage"][1]["total_tokens"] == 22
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True
