mirror of https://github.com/HazyResearch/manifest
fix: added pydantic types to response (#84)
parent
4602fb919b
commit
d7401c6ec5
@ -1,230 +1,301 @@
|
|||||||
"""Response test."""
|
"""Response test."""
|
||||||
from typing import Any, Dict
|
from typing import List, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from manifest import Response
|
from manifest import Response
|
||||||
from manifest.request import LMRequest
|
from manifest.request import EmbeddingRequest, LMRequest
|
||||||
|
from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages
|
||||||
|
|
||||||
|
|
||||||
def test_init() -> None:
|
def test_init(
|
||||||
|
model_choice: ModelChoices,
|
||||||
|
model_choice_arr: ModelChoices,
|
||||||
|
model_choice_arr_int: ModelChoices,
|
||||||
|
request_lm: LMRequest,
|
||||||
|
request_array: EmbeddingRequest,
|
||||||
|
) -> None:
|
||||||
"""Test response initialization."""
|
"""Test response initialization."""
|
||||||
with pytest.raises(ValueError) as exc_info:
|
response = Response(
|
||||||
response = Response(4, False, {}) # type: ignore
|
response=model_choice,
|
||||||
assert str(exc_info.value) == "Response must be dict. Response is\n4."
|
cached=False,
|
||||||
with pytest.raises(ValueError) as exc_info:
|
request=request_lm,
|
||||||
response = Response({"test": "hello"}, False, {})
|
usages=None,
|
||||||
assert str(exc_info.value) == (
|
request_type=LMRequest,
|
||||||
"Response must be serialized to a dict with a nonempty list of choices. "
|
response_type="text",
|
||||||
"Response is\n{'test': 'hello'}."
|
)
|
||||||
)
|
assert response._response == model_choice
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
response = Response({"choices": [{"blah": "hello"}]}, False, {})
|
|
||||||
assert str(exc_info.value) == (
|
|
||||||
"Response must be serialized to a dict "
|
|
||||||
"with a list of choices with text field"
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
response = Response({"choices": []}, False, {})
|
|
||||||
assert str(exc_info.value) == (
|
|
||||||
"Response must be serialized to a dict with a nonempty list of choices. "
|
|
||||||
"Response is\n{'choices': []}."
|
|
||||||
)
|
|
||||||
|
|
||||||
response = Response({"choices": [{"text": "hello"}]}, False, {})
|
|
||||||
assert response._response == {"choices": [{"text": "hello"}]}
|
|
||||||
assert response._cached is False
|
assert response._cached is False
|
||||||
assert response._request_params == {}
|
assert response._request == request_lm
|
||||||
assert response.item_dtype is None
|
assert response._usages == Usages(usages=[])
|
||||||
|
assert response._request_type == LMRequest
|
||||||
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
|
assert response._response_type == "text"
|
||||||
assert response._response == {"choices": [{"text": "hello"}]}
|
assert response._item_dtype is None
|
||||||
assert response._cached is True
|
|
||||||
assert response._request_params == {"request": "yoyo"}
|
|
||||||
assert response.item_dtype is None
|
|
||||||
|
|
||||||
response = Response(
|
response = Response(
|
||||||
{"generations": [{"txt": "hello"}], "logits": []},
|
response=model_choice_arr_int,
|
||||||
False,
|
cached=False,
|
||||||
{},
|
request=request_array,
|
||||||
generation_key="generations",
|
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
|
||||||
logits_key="logits",
|
request_type=EmbeddingRequest,
|
||||||
item_key="txt",
|
response_type="array",
|
||||||
)
|
)
|
||||||
assert response._response == {"generations": [{"txt": "hello"}], "logits": []}
|
|
||||||
assert response._cached is False
|
assert response._cached is False
|
||||||
assert response._request_params == {}
|
assert response._request == request_array
|
||||||
assert response.item_dtype is None
|
assert sum([usg.total_tokens for usg in response._usages.usages]) == 10
|
||||||
|
assert response._request_type == EmbeddingRequest
|
||||||
|
assert response._response_type == "array"
|
||||||
|
assert response._item_dtype == "int64"
|
||||||
|
|
||||||
int_arr = np.random.randint(20, size=(4, 4))
|
with pytest.raises(ValueError) as excinfo:
|
||||||
response = Response(
|
Response(
|
||||||
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
|
response=model_choice,
|
||||||
|
cached=False,
|
||||||
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="blah",
|
||||||
|
)
|
||||||
|
assert "blah" in str(excinfo.value)
|
||||||
|
|
||||||
|
# Can't convert array with text
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
Response(
|
||||||
|
response=model_choice,
|
||||||
|
cached=False,
|
||||||
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="array",
|
||||||
|
)
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"response_type is array but response is "
|
||||||
|
"<class 'manifest.response.LMModelChoice'>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Can't convert text with array
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
Response(
|
||||||
|
response=model_choice_arr,
|
||||||
|
cached=False,
|
||||||
|
request=request_array,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="text",
|
||||||
|
)
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"response_type is text but response is "
|
||||||
|
"<class 'manifest.response.ArrayModelChoice'>"
|
||||||
)
|
)
|
||||||
assert response._response == {"choices": [{"array": int_arr}]}
|
|
||||||
assert response._cached is True
|
|
||||||
assert response._request_params == {"request": "yoyo"}
|
|
||||||
assert response.item_dtype == "int64"
|
|
||||||
|
|
||||||
|
|
||||||
def test_getters() -> None:
|
def test_getters(model_choice: ModelChoices, request_lm: LMRequest) -> None:
|
||||||
"""Test response cached."""
|
"""Test response cached."""
|
||||||
response = Response({"choices": [{"text": "hello"}]}, False, {})
|
response = Response(
|
||||||
assert response.get_json_response() == {"choices": [{"text": "hello"}]}
|
response=model_choice,
|
||||||
|
cached=False,
|
||||||
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="text",
|
||||||
|
)
|
||||||
|
assert response.get_response_obj() == model_choice
|
||||||
assert response.is_cached() is False
|
assert response.is_cached() is False
|
||||||
assert response.get_request() == {}
|
assert response.get_request_obj() == request_lm
|
||||||
|
assert response.get_usage_obj() == Usages(usages=[])
|
||||||
|
assert response.get_json_response() == model_choice.dict()
|
||||||
|
assert response.get_response() == ["hello", "bye"]
|
||||||
|
|
||||||
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
|
|
||||||
assert response.get_json_response() == {"choices": [{"text": "hello"}]}
|
|
||||||
assert response.is_cached() is True
|
|
||||||
assert response.get_request() == {"request": "yoyo"}
|
|
||||||
|
|
||||||
int_arr = np.random.randint(20, size=(4, 4))
|
def test_serialize(
|
||||||
|
model_choice: ModelChoices,
|
||||||
|
model_choice_arr: ModelChoices,
|
||||||
|
model_choice_arr_int: ModelChoices,
|
||||||
|
request_lm: LMRequest,
|
||||||
|
request_array: EmbeddingRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test response serialization."""
|
||||||
response = Response(
|
response = Response(
|
||||||
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
|
response=model_choice,
|
||||||
|
cached=False,
|
||||||
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="text",
|
||||||
)
|
)
|
||||||
assert response.get_json_response() == {"choices": [{"array": int_arr}]}
|
deserialized_response = Response.deserialize(response.serialize())
|
||||||
assert response.is_cached() is True
|
assert deserialized_response.get_response_obj() == model_choice
|
||||||
assert response.get_request() == {"request": "yoyo"}
|
assert deserialized_response.is_cached() is False
|
||||||
|
assert deserialized_response.get_request_obj() == request_lm
|
||||||
|
assert deserialized_response.get_usage_obj() == Usages(usages=[])
|
||||||
|
assert deserialized_response.get_json_response() == model_choice.dict()
|
||||||
|
assert deserialized_response.get_response() == ["hello", "bye"]
|
||||||
|
|
||||||
|
deserialized_response = Response.from_dict(response.to_dict())
|
||||||
|
assert deserialized_response.get_response_obj() == model_choice
|
||||||
|
assert deserialized_response.is_cached() is False
|
||||||
|
assert deserialized_response.get_request_obj() == request_lm
|
||||||
|
assert deserialized_response.get_usage_obj() == Usages(usages=[])
|
||||||
|
assert deserialized_response.get_json_response() == model_choice.dict()
|
||||||
|
assert deserialized_response.get_response() == ["hello", "bye"]
|
||||||
|
|
||||||
def test_serialize() -> None:
|
deserialized_response = Response.from_dict(
|
||||||
"""Test response serialization."""
|
response.to_dict(drop_request=True), request_dict={"prompt": "blahhhh"}
|
||||||
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
|
)
|
||||||
deserialized_response = Response.deserialize(response.serialize())
|
assert deserialized_response.get_response_obj() == model_choice
|
||||||
assert deserialized_response._response == {"choices": [{"text": "hello"}]}
|
assert deserialized_response.is_cached() is False
|
||||||
assert deserialized_response.is_cached() is True
|
assert deserialized_response.get_request_obj().prompt == "blahhhh"
|
||||||
assert deserialized_response._request_params == {"request": "yoyo"}
|
assert deserialized_response.get_usage_obj() == Usages(usages=[])
|
||||||
|
assert deserialized_response.get_json_response() == model_choice.dict()
|
||||||
|
assert deserialized_response.get_response() == ["hello", "bye"]
|
||||||
|
|
||||||
int_arr = np.random.randint(20, size=(4, 4))
|
# Int type
|
||||||
response = Response(
|
response = Response(
|
||||||
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
|
response=model_choice_arr_int,
|
||||||
|
cached=False,
|
||||||
|
request=request_array,
|
||||||
|
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
|
||||||
|
request_type=EmbeddingRequest,
|
||||||
|
response_type="array",
|
||||||
)
|
)
|
||||||
deserialized_response = Response.deserialize(response.serialize())
|
deserialized_response = Response.deserialize(response.serialize())
|
||||||
|
assert deserialized_response._item_dtype == "int64"
|
||||||
|
assert (
|
||||||
|
cast(
|
||||||
|
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
|
||||||
|
).array.dtype
|
||||||
|
== np.int64
|
||||||
|
)
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
deserialized_response._response["choices"][0]["array"], int_arr
|
cast(
|
||||||
|
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
|
||||||
|
).array,
|
||||||
|
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
|
||||||
)
|
)
|
||||||
assert deserialized_response.is_cached() is True
|
|
||||||
assert deserialized_response._request_params == {"request": "yoyo"}
|
|
||||||
|
|
||||||
float_arr = np.random.randn(4, 4)
|
# Float type
|
||||||
response = Response(
|
response = Response(
|
||||||
{"choices": [{"array": float_arr}]}, True, {"request": "yoyo"}, item_key="array"
|
response=model_choice_arr,
|
||||||
|
cached=False,
|
||||||
|
request=request_array,
|
||||||
|
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
|
||||||
|
request_type=EmbeddingRequest,
|
||||||
|
response_type="array",
|
||||||
)
|
)
|
||||||
deserialized_response = Response.deserialize(response.serialize())
|
deserialized_response = Response.deserialize(response.serialize())
|
||||||
|
assert deserialized_response._item_dtype == "float64"
|
||||||
|
assert (
|
||||||
|
cast(
|
||||||
|
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
|
||||||
|
).array.dtype
|
||||||
|
== np.float64
|
||||||
|
)
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
deserialized_response._response["choices"][0]["array"], float_arr
|
cast(
|
||||||
|
ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
|
||||||
|
).array,
|
||||||
|
cast(ArrayModelChoice, model_choice_arr.choices[0]).array,
|
||||||
)
|
)
|
||||||
assert deserialized_response.is_cached() is True
|
|
||||||
assert deserialized_response._request_params == {"request": "yoyo"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_results() -> None:
|
def test_get_results(
|
||||||
|
model_choice: ModelChoices,
|
||||||
|
model_choice_single: ModelChoices,
|
||||||
|
model_choice_arr: ModelChoices,
|
||||||
|
request_lm: LMRequest,
|
||||||
|
request_array: EmbeddingRequest,
|
||||||
|
) -> None:
|
||||||
"""Test response get results."""
|
"""Test response get results."""
|
||||||
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
|
response = Response(
|
||||||
assert response.get_response() == "hello"
|
response=model_choice_single,
|
||||||
|
cached=False,
|
||||||
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="text",
|
||||||
|
)
|
||||||
|
assert response.get_response() == "helloo"
|
||||||
assert response.get_response(stop_token="ll") == "he"
|
assert response.get_response(stop_token="ll") == "he"
|
||||||
assert response.get_response(stop_token="ll", is_batch=True) == ["he"]
|
assert response.get_response(stop_token="ll", is_batch=True) == ["he"]
|
||||||
|
|
||||||
response = Response(
|
response = Response(
|
||||||
{"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
|
response=model_choice,
|
||||||
True,
|
cached=False,
|
||||||
{"request": "yoyo"},
|
request=request_lm,
|
||||||
|
usages=None,
|
||||||
|
request_type=LMRequest,
|
||||||
|
response_type="text",
|
||||||
)
|
)
|
||||||
assert response.get_response() == ["hello", "my", "name"]
|
assert response.get_response() == ["hello", "bye"]
|
||||||
assert response.get_response(stop_token="m") == ["hello", "", "na"]
|
assert response.get_response(stop_token="b") == ["hello", ""]
|
||||||
assert response.get_response(stop_token="m", is_batch=True) == ["hello", "", "na"]
|
assert response.get_response(stop_token="y", is_batch=True) == ["hello", "b"]
|
||||||
|
|
||||||
float_arr = np.random.randn(4, 4)
|
float_arr1 = cast(ArrayModelChoice, model_choice_arr.choices[0]).array
|
||||||
|
float_arr2 = cast(ArrayModelChoice, model_choice_arr.choices[1]).array
|
||||||
response = Response(
|
response = Response(
|
||||||
{"choices": [{"array": float_arr}, {"array": float_arr}]},
|
response=model_choice_arr,
|
||||||
True,
|
cached=False,
|
||||||
{"request": "yoyo"},
|
request=request_array,
|
||||||
item_key="array",
|
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
|
||||||
|
request_type=EmbeddingRequest,
|
||||||
|
response_type="array",
|
||||||
)
|
)
|
||||||
assert response.get_response() == [float_arr, float_arr]
|
assert np.array_equal(response.get_response()[0], float_arr1)
|
||||||
assert response.get_response(stop_token="m") == [float_arr, float_arr]
|
assert np.array_equal(response.get_response()[1], float_arr2)
|
||||||
|
assert np.array_equal(response.get_response(stop_token="t")[0], float_arr1)
|
||||||
|
assert np.array_equal(response.get_response(stop_token="t")[1], float_arr2)
|
||||||
|
|
||||||
|
|
||||||
def test_union_all() -> None:
|
def test_union_all(
|
||||||
|
model_choice: ModelChoices,
|
||||||
|
model_choice_single: ModelChoices,
|
||||||
|
request_lm: LMRequest,
|
||||||
|
request_lm_single: LMRequest,
|
||||||
|
) -> None:
|
||||||
"""Test union all."""
|
"""Test union all."""
|
||||||
request_paramsa = LMRequest(prompt=["apple", "orange", "pear"]).to_dict()
|
response1 = Response(
|
||||||
request_paramsa["model"] = "modelA"
|
response=model_choice,
|
||||||
response_paramsa = {
|
cached=False,
|
||||||
"choices": [
|
request=request_lm,
|
||||||
{"text": "hello", "token_logprobs": [1]},
|
usages=None,
|
||||||
{"text": "hello 2", "token_logprobs": [1]},
|
request_type=LMRequest,
|
||||||
{"text": "hello 3", "token_logprobs": [1]},
|
response_type="text",
|
||||||
]
|
)
|
||||||
}
|
|
||||||
responsea = Response(response_paramsa, False, request_paramsa)
|
|
||||||
|
|
||||||
request_paramsb = LMRequest(prompt=["banana", "pineapple", "mango"]).to_dict()
|
response2 = Response(
|
||||||
request_paramsb["model"] = "modelB"
|
response=model_choice_single,
|
||||||
response_paramsb = {
|
cached=False,
|
||||||
"choices": [
|
request=request_lm_single,
|
||||||
{"text": "bye", "token_logprobs": [2]},
|
usages=None,
|
||||||
{"text": "bye 2", "token_logprobs": [2]},
|
request_type=LMRequest,
|
||||||
{"text": "bye 3", "token_logprobs": [2]},
|
response_type="text",
|
||||||
]
|
)
|
||||||
}
|
|
||||||
responseb = Response(response_paramsb, False, request_paramsb)
|
|
||||||
|
|
||||||
final_response = Response.union_all([responsea, responseb])
|
final_response = Response.union_all([response1, response2])
|
||||||
assert final_response.get_json_response() == {
|
assert final_response.get_json_response() == {
|
||||||
"choices": [
|
"choices": [
|
||||||
{"text": "hello", "token_logprobs": [1]},
|
{"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None},
|
||||||
{"text": "hello 2", "token_logprobs": [1]},
|
{"text": "bye", "token_logprobs": [0.3], "tokens": None},
|
||||||
{"text": "hello 3", "token_logprobs": [1]},
|
{"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None},
|
||||||
{"text": "bye", "token_logprobs": [2]},
|
|
||||||
{"text": "bye 2", "token_logprobs": [2]},
|
|
||||||
{"text": "bye 3", "token_logprobs": [2]},
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
final_request = LMRequest(
|
assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
|
||||||
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
|
merged_prompts: List[str] = request_lm.prompt + [request_lm_single.prompt] # type: ignore # noqa: E501
|
||||||
).to_dict()
|
assert final_response.get_request_obj().prompt == merged_prompts
|
||||||
final_request["model"] = "modelA"
|
assert final_response.get_request_obj().engine == "dummy::text-ada-001"
|
||||||
assert final_response.get_request() == final_request
|
|
||||||
assert not final_response.is_cached()
|
|
||||||
|
|
||||||
# Modify A to have usage and cached
|
# Modify A to have usage and cached
|
||||||
response_paramsa_2: Dict[str, Any] = {
|
response1 = Response(
|
||||||
"choices": [
|
response=model_choice,
|
||||||
{"text": "hello", "token_logprobs": [1]},
|
cached=False,
|
||||||
{"text": "hello 2", "token_logprobs": [1]},
|
request=request_lm,
|
||||||
{"text": "hello 3", "token_logprobs": [1]},
|
usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
|
||||||
],
|
request_type=LMRequest,
|
||||||
"usage": [
|
response_type="text",
|
||||||
{"completion_tokens": 10},
|
)
|
||||||
{"completion_tokens": 10},
|
|
||||||
{"completion_tokens": 10},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
responsea = Response(response_paramsa_2, True, request_paramsa)
|
|
||||||
|
|
||||||
final_response = Response.union_all([responsea, responseb])
|
final_response = Response.union_all([response1, response2])
|
||||||
assert final_response.get_json_response() == {
|
assert final_response.get_usage_obj() == Usages(
|
||||||
"choices": [
|
usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()]
|
||||||
{"text": "hello", "token_logprobs": [1]},
|
)
|
||||||
{"text": "hello 2", "token_logprobs": [1]},
|
|
||||||
{"text": "hello 3", "token_logprobs": [1]},
|
|
||||||
{"text": "bye", "token_logprobs": [2]},
|
|
||||||
{"text": "bye 2", "token_logprobs": [2]},
|
|
||||||
{"text": "bye 3", "token_logprobs": [2]},
|
|
||||||
],
|
|
||||||
"usage": [
|
|
||||||
{"completion_tokens": 10},
|
|
||||||
{"completion_tokens": 10},
|
|
||||||
{"completion_tokens": 10},
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
final_request = LMRequest(
|
|
||||||
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
|
|
||||||
).to_dict()
|
|
||||||
final_request["model"] = "modelA"
|
|
||||||
assert final_response.get_request() == final_request
|
|
||||||
assert final_response.is_cached()
|
|
||||||
|
Loading…
Reference in New Issue