fix: dummy client to output tokens and random responses (#106)

pull/109/head
Laurel Orr 10 months ago committed by GitHub
parent b775d15f2e
commit 49f51952df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,10 @@
"""Dummy client."""
import hashlib
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import tiktoken
from manifest.clients.client import Client
from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
@ -14,7 +18,13 @@ class DummyClient(Client):
# User param -> (client param, default value)
PARAMS = {
"n": ("num_results", 1),
"engine": ("model", "text-davinci-003"),
"temperature": ("temperature", 0.0),
"max_tokens": ("max_tokens", 10),
"n": ("n", 1),
"top_p": ("top_p", 1.0),
"top_k": ("best_of", 1),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS = LMRequest
NAME = "dummy"
@ -33,6 +43,9 @@ class DummyClient(Client):
connection_str: connection string.
client_args: client arguments.
"""
        # We use tiktoken as it is faster than HF for tokenizing
# Use any model to create the tokenizer
self.encoder = tiktoken.get_encoding("cl100k_base")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
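For background on the tokenizer wired in above, here is a minimal sketch, not part of the diff, of the cl100k_base encode/decode round trip; the prompt string is illustrative only:

import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")
token_ids = encoder.encode("This is a prompt")      # one int per BPE token
assert encoder.decode(token_ids) == "This is a prompt"
# encoder.max_token_value bounds the IDs the dummy client samples from.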
@ -74,7 +87,65 @@ class DummyClient(Client):
Returns:
model params.
"""
return {"engine": "dummy"}
return {"engine": "dummy", "model": getattr(self, "engine")}
def get_mock_output(
self, output_toks: int, is_completion: bool, seed: Optional[int] = None
) -> LMModelChoice:
"""Return mock model output by generating random tokens."""
np.random.seed(seed)
random_tokens = np.random.randint(
0, self.encoder.max_token_value + 1, output_toks
)
response = self.encoder.decode(random_tokens) # type: ignore
if is_completion:
np.random.seed(seed)
random_logprobs = np.random.uniform(
low=-2, high=-0.00001, size=output_toks
).tolist()
else:
# Return all Nones to mimic chat models
# OpenAI chat models do not return logprobs
random_logprobs = [None] * output_toks
return LMModelChoice(
text=response,
token_logprobs=random_logprobs,
tokens=random_tokens.tolist(),
)
def get_mock_choices(
self,
prompt_list: List[str],
request_params: Dict,
is_completion: bool,
) -> Tuple[List[LMModelChoice], List[Usage]]:
"""Get choices and usages of mock output."""
choices = []
usages = []
for prompt in prompt_list:
num_prompt_tokens = len(self.encoder.encode(prompt))
if request_params["temperature"] == 0:
# Get integer seed from hash of prompt
seed = (
int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
% 10**8
)
else:
                # Leave the seed unset so each call draws fresh randomness
seed = None
for _ in range(int(request_params["n"])):
choice = self.get_mock_output(
request_params["max_tokens"], is_completion=is_completion, seed=seed
)
choices.append(choice)
usages.append(
Usage(
prompt_tokens=num_prompt_tokens,
completion_tokens=request_params["max_tokens"],
total_tokens=num_prompt_tokens + request_params["max_tokens"],
)
)
return choices, usages
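The temperature == 0 branch above makes the mock output reproducible: the seed is derived from a hash of the prompt, so the same prompt always produces the same tokens. A minimal sketch of that property, prompt string illustrative:

import hashlib

import numpy as np
import tiktoken

prompt = "This is a prompt"
seed = int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16) % 10**8
encoder = tiktoken.get_encoding("cl100k_base")

np.random.seed(seed)
first = np.random.randint(0, encoder.max_token_value + 1, 10)
np.random.seed(seed)
second = np.random.randint(0, encoder.max_token_value + 1, 10)
assert (first == second).all()  # identical draws -> identical mock completion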
def run_request(self, request: Request) -> Response:
"""
@ -88,32 +159,19 @@ class DummyClient(Client):
request parameters as dict.
"""
if isinstance(request.prompt, list):
num_results = len(request.prompt)
prompt_list = request.prompt
else:
num_results = 1
prompt_list = [request.prompt]
request_params = request.to_dict(self.PARAMS)
choices, usages = self.get_mock_choices(
prompt_list, request_params, is_completion=True
)
return Response(
response=ModelChoices(
choices=[LMModelChoice(text="hello")] # type: ignore
* int(request_params["num_results"])
* num_results
),
response=ModelChoices(choices=choices), # type: ignore
cached=False,
request=request,
usages=Usages(
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
* int(request_params["num_results"])
* num_results
),
usages=Usages(usages=usages),
response_type="text",
request_type=self.REQUEST_CLS,
)
@ -145,35 +203,17 @@ class DummyClient(Client):
Returns:
response.
"""
num_results = 1
response_dict = {
"choices": [
{
"text": request.prompt[0]["content"],
}
for i in range(num_results)
]
}
prompt_list = ["_".join(pmp["content"] for pmp in request.prompt)]
request_params = request.to_dict(self.PARAMS)
choices, usages = self.get_mock_choices(
prompt_list, request_params, is_completion=False
)
return Response(
response=ModelChoices(
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
response=ModelChoices(choices=choices), # type: ignore
cached=False,
request=request,
usages=Usages(
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
),
usages=Usages(usages=usages),
response_type="text",
request_type=LMChatRequest,
)
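For the chat path above, the messages are flattened into a single prompt string before seeding, so a given conversation reproduces the same mock reply at temperature 0. A small sketch of the flattening, with messages taken from the tests below:

messages = [
    {"role": "system", "content": "Hello."},
    {"role": "user", "content": "Goodbye?"},
]
prompt_list = ["_".join(m["content"] for m in messages)]
assert prompt_list == ["Hello._Goodbye?"]
# This joined string is what get_mock_choices hashes for the deterministic seed.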
@ -193,30 +233,19 @@ class DummyClient(Client):
request parameters as dict.
"""
if isinstance(request.prompt, list):
num_results = len(request.prompt)
prompt_list = request.prompt
else:
num_results = 1
response_dict = {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"token_logprobs": [0.3],
}
for i in range(num_results)
]
}
prompt_list = [request.prompt]
request_params = request.to_dict(self.PARAMS)
choices, usages = self.get_mock_choices(
prompt_list, request_params, is_completion=True
)
return Response(
response=ModelChoices(
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
response=ModelChoices(choices=choices), # type: ignore
cached=False,
request=request,
usages=None,
usages=Usages(usages=usages),
response_type="text",
request_type=LMScoreRequest,
)

@ -53,7 +53,7 @@ class LMModelChoice(BaseModel):
"""Model single completion."""
text: str
token_logprobs: Optional[List[float]] = None
token_logprobs: Optional[List[Optional[float]]] = None
tokens: Optional[List[str]] = None
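The widened Optional[List[Optional[float]]] type lets a choice carry one logprob slot per token while leaving individual entries None, matching chat models. A small construction sketch, with illustrative values and assuming LMModelChoice from this module is in scope:

chat_choice = LMModelChoice(
    text="mock chat output",
    token_logprobs=[None, None, None],  # chat models return no logprobs
    tokens=["1", "2", "3"],
)
completion_choice = LMModelChoice(
    text="mock completion output",
    token_logprobs=[-0.5, -1.2, -0.1],  # completion models return real logprobs
    tokens=["4", "5", "6"],
)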

@ -19,8 +19,19 @@ def test_init() -> None:
def test_get_params() -> None:
"""Test get param functions."""
client = DummyClient(connection_str=None)
assert client.get_model_params() == {"engine": "dummy"}
assert client.get_model_inputs() == ["n"]
assert client.get_model_params() == {
"engine": "dummy",
"model": "text-davinci-003",
}
assert client.get_model_inputs() == [
"engine",
"temperature",
"max_tokens",
"n",
"top_p",
"top_k",
"batch_size",
]
def test_get_request() -> None:
@ -31,43 +42,148 @@ def test_get_request() -> None:
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 3,
"model": "text-davinci-003",
"n": 3,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3,
"choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 3
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
"usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 3,
}
request_params = client.get_request("hello", {"n": 5})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 5,
"model": "text-davinci-003",
"n": 5,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
"choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 5
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
"usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 5,
}
request_params = client.get_request(["hello"] * 5, {"n": 1})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": ["hello"] * 5,
"num_results": 1,
"model": "text-davinci-003",
"n": 1,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy",
"request_cls": "LMRequest",
}
assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5,
"choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 5
}
assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
"usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 5,
}

@ -73,6 +73,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
cache_name="sqlite",
cache_connection=sqlite_cache,
n=n,
temperature=0.0,
)
prompt = "This is a prompt"
@ -80,8 +81,6 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, return_response=return_response, bad_input=5)
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
# Allow params in the request object but not in the client to go through
assert "top_k" not in manifest.client_pool.get_next_client().PARAMS
result = manifest.run(prompt, return_response=return_response, top_k=5)
assert result is not None
@ -96,21 +95,30 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
if n == 1:
assert res == "hello"
assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
else:
assert res == ["hello", "hello"]
assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
]
prompt = "This is a prompt"
result = manifest.run(prompt, run_id="34", return_response=return_response)
@ -126,19 +134,27 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest",
"num_results": n,
"temperature": 0.0,
"top_p": 1.0,
"run_id": "34",
}
)
is not None
)
if n == 1:
assert res == "hello"
assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
else:
assert res == ["hello", "hello"]
assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
]
prompt = "Hello is a prompt"
result = manifest.run(prompt, return_response=return_response)
@ -154,45 +170,60 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
if n == 1:
assert res == "hello"
assert res == "appersstoff210 currentNodeleh norm unified_voice DIYHam"
else:
assert res == ["hello", "hello"]
assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
prompt = "Hello is a prompt"
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
result = manifest.run(
prompt, stop_token=" current", return_response=return_response
)
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(stop_token="ll")
res = result.get_response(stop_token=" current")
else:
res = cast(str, result)
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
if n == 1:
assert res == "he"
assert res == "appersstoff210"
else:
assert res == ["he", "he"]
assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache")
@ -205,6 +236,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
cache_name="sqlite",
cache_connection=sqlite_cache,
n=n,
temperature=0.0,
)
prompt = ["This is a prompt"]
if n == 2:
@ -222,15 +254,20 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token, is_batch=True)
else:
res = cast(str, result)
assert res == ["hello"]
assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -246,15 +283,23 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token, is_batch=True)
else:
res = cast(str, result)
assert res == ["hello", "hello"]
assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -266,11 +311,16 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert (
manifest.cache.get(
{
"prompt": "New prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "New prompt",
"request_cls": "LMRequest",
"num_results": n,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is None
)
@ -287,20 +337,25 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert result.is_cached()
else:
res = cast(str, result)
assert res == ["hello", "hello"]
assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
".vol.deserializebigmnchantment ROTıl='')\najsС",
]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
result = manifest.run(
prompt, stop_token=" current", return_response=return_response
)
if return_response:
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices
)
res = result.get_response(stop_token="ll", is_batch=True)
res = result.get_response(stop_token=" current", is_batch=True)
else:
res = cast(str, result)
assert res == ["he", "he"]
assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache")
@ -310,6 +365,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
temperature=0.0,
)
prompt = ["This is a prompt"]
result = cast(
@ -318,15 +374,20 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello"]
assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "This is a prompt",
"request_cls": "LMRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -338,15 +399,23 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello", "hello"]
assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -362,11 +431,16 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert (
manifest.cache.get(
{
"prompt": "New prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "New prompt",
"request_cls": "LMRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is None
)
@ -379,7 +453,10 @@ def test_abatch_run(sqlite_cache: str) -> None:
res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
assert result.is_cached()
assert res == ["hello", "hello"]
assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
".vol.deserializebigmnchantment ROTıl='')\najsС",
]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = cast(
@ -387,8 +464,8 @@ def test_abatch_run(sqlite_cache: str) -> None:
)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(stop_token="ll", is_batch=True)
assert res == ["he", "he"]
res = result.get_response(stop_token=" current", is_batch=True)
assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache")
@ -398,6 +475,7 @@ def test_run_chat(sqlite_cache: str) -> None:
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
temperature=0.0,
)
# Set CHAT to be true for this model
manifest.client_pool.client_pool[0].IS_CHAT = True
@ -406,15 +484,23 @@ def test_run_chat(sqlite_cache: str) -> None:
{"role": "system", "content": "Hello."},
]
result = manifest.run(prompt, return_response=False)
assert result == "Hello."
assert (
result
== "ectors WortGo ré_sg|--------------------------------------------------------------------------\n contradictory Aad \u200b getUserId" # noqa: E501
)
assert (
manifest.cache.get(
{
"prompt": [{"content": "Hello.", "role": "system"}],
"best_of": 1,
"engine": "dummy",
"num_results": 1,
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": [{"content": "Hello.", "role": "system"}],
"request_cls": "LMChatRequest",
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -428,18 +514,23 @@ def test_run_chat(sqlite_cache: str) -> None:
result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response()
assert res == "Hello."
assert res == "_deploy_age_gp hora Plus Scheduler EisenhowerRF视 chemotherapy"
assert (
manifest.cache.get(
{
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": [
{"role": "system", "content": "Hello."},
{"role": "user", "content": "Goodbye?"},
],
"engine": "dummy",
"num_results": 1,
"request_cls": "LMChatRequest",
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
@ -452,6 +543,7 @@ def test_score_run(sqlite_cache: str) -> None:
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
temperature=0.0,
)
prompt = "This is a prompt"
@ -459,33 +551,68 @@ def test_score_run(sqlite_cache: str) -> None:
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "This is a prompt",
"request_cls": "LMScoreRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
assert result == {
"response": {
"choices": [
{"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None}
{
"text": "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"token_logprobs": [
-1.827188890438529,
-1.6981601736417915,
-0.24606708391178755,
-1.9209383499010613,
-0.8833563758318617,
-1.4121369466920703,
-0.376352908076236,
-1.3200064558188096,
-0.813028447207917,
-0.5977255311239729,
],
"tokens": [
"46078",
"21445",
"48305",
"7927",
"76125",
"46233",
"34581",
"23679",
"63021",
"78158",
],
}
]
},
"usages": {
"usages": [
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14}
]
},
"usages": {"usages": []},
"cached": False,
"request": {
"prompt": "This is a prompt",
"engine": "text-ada-001",
"engine": "text-davinci-003",
"n": 1,
"client_timeout": 60,
"run_id": None,
"batch_size": 8,
"temperature": 0.7,
"max_tokens": 100,
"batch_size": 20,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"top_k": 50,
"top_k": 1,
"logprobs": None,
"stop_sequences": None,
"num_beams": 1,
@ -505,49 +632,112 @@ def test_score_run(sqlite_cache: str) -> None:
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is a prompt",
"request_cls": "LMScoreRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
assert (
manifest.cache.get(
{
"prompt": "Hello is another prompt",
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is another prompt",
"request_cls": "LMScoreRequest",
"num_results": 1,
},
"temperature": 0.0,
"top_p": 1.0,
}
)
is not None
)
assert result == {
"response": {
"choices": [
{"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None},
{
"text": "Hello is another prompt",
"token_logprobs": [0.3],
"tokens": None,
"text": "appersstoff210 currentNodeleh norm unified_voice DIYHam",
"token_logprobs": [
-0.5613340599860608,
-1.2822870706137146,
-1.9909319620162806,
-0.6312373658222814,
-1.9066239705571664,
-1.2420939968397082,
-0.7208735169940805,
-1.9144266963723062,
-0.041181937860757856,
-0.5356282450367043,
],
"tokens": [
"28921",
"81056",
"8848",
"47399",
"74890",
"7617",
"43790",
"77865",
"32558",
"41041",
],
},
{
"text": ".addAttribute_size DE imageUrl_datas\tapFixed(hour setups\tcomment", # noqa: E501
"token_logprobs": [
-1.1142500072582333,
-0.819706434396527,
-1.9956443391600693,
-0.8425896744807639,
-1.8398050571245623,
-1.912564137256891,
-1.6677665162080606,
-1.1579612203844727,
-1.9876114502998343,
-0.2698297864722319,
],
"tokens": [
"26300",
"2424",
"3467",
"40749",
"47630",
"70998",
"13829",
"72135",
"84823",
"97368",
],
},
]
},
"usages": {
"usages": [
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
]
},
"usages": {"usages": []},
"cached": False,
"request": {
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "text-ada-001",
"engine": "text-davinci-003",
"n": 1,
"client_timeout": 60,
"run_id": None,
"batch_size": 8,
"temperature": 0.7,
"max_tokens": 100,
"batch_size": 20,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"top_k": 50,
"top_k": 1,
"logprobs": None,
"stop_sequences": None,
"num_beams": 1,
