feat: support token_logprobs (#59)

pull/82/head
Laurel Orr authored 1 year ago, committed by GitHub
parent c6331770d4
commit c4ad007f02

File diff suppressed because one or more lines are too long

@@ -2,14 +2,22 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: TOMA_URL=<TOMA_URL>\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"%env TOMA_URL=https://staging.together.xyz/api"
"%env TOMA_URL=<TOMA_URL>"
]
},
{
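Outside a notebook, the same setting can be applied before Manifest is used; a minimal sketch, with <TOMA_URL> remaining a placeholder for your own deployment:

import os

# <TOMA_URL> is a placeholder; substitute the full URL of your TOMA API deployment.
os.environ["TOMA_URL"] = "<TOMA_URL>"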

@@ -173,10 +173,16 @@ def completions() -> Response:
result_gens.append(generations)
if model_type == "diffuser":
# Assign None logprob as it's not supported in diffusers
results = [{"array": r[0], "logprob": None} for r in result_gens]
results = [
{"array": r[0], "logprob": None, "token_logprobs": None}
for r in result_gens
]
res_type = "image_generation"
else:
results = [{"text": r[0], "logprob": r[1]} for r in result_gens]
results = [
{"text": r[0], "logprob": r[1], "token_logprobs": r[2]}
for r in result_gens
]
res_type = "text_completion"
# transform the result into the openai format
return Response(
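A sketch of what one results entry is assumed to look like for a text completion after this change (field names from the hunk above, values illustrative):

results = [
    {
        "text": " world",                  # generated text
        "logprob": -4.21,                  # sequence logprob, i.e. sum of the token logprobs
        "token_logprobs": [-1.37, -2.84],  # one logprob per generated token
    }
]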
@@ -232,7 +238,11 @@ def score_sequence() -> Response:
try:
score_list = model.score_sequence(prompt, **generation_args)
results = [
{"text": prompt if isinstance(prompt, str) else prompt[i], "logprob": r}
{
"text": prompt if isinstance(prompt, str) else prompt[i],
"logprob": r[0],
"token_logprobs": r[1],
}
for i, r in enumerate(score_list)
]
# transform the result into the openai format
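With this change each entry of score_list is assumed to be a (sequence logprob, per-token logprobs) pair rather than a bare float, which is why r[0] and r[1] are unpacked above; a minimal sketch with illustrative values:

import math

score_list = [(-19.935, [-2.1, -3.3, -4.8, -9.735])]  # hypothetical single-prompt result
seq_logprob, token_logprobs = score_list[0]
assert math.isclose(seq_logprob, sum(token_logprobs))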

@@ -74,7 +74,7 @@ class DiffuserModel(Model):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float]]:
) -> List[Tuple[Any, float, List[float]]]:
"""
Generate the prompt from model.
@@ -90,13 +90,13 @@ class DiffuserModel(Model):
if isinstance(prompt, str):
prompt = [prompt]
result = self.pipeline(prompt, output_type="np.array", **kwargs)
# Return None for logprobs
return [(im, None) for im in result["images"]]
# Return None for logprobs and token logprobs
return [(im, None, None) for im in result["images"]]
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[float]:
) -> List[Tuple[float, List[float]]]:
"""
Score a sequence of choices.
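A sketch of consuming the updated DiffuserModel.generate return shape, assuming a loaded model instance; the point is the three-element unpacking, with both logprob fields None because diffusers do not expose them:

for image, logprob, token_logprobs in model.generate("an astronaut riding a horse"):
    assert logprob is None and token_logprobs is None  # not supported for diffusers
    handle_image(image)  # hypothetical helper; image is an np.array from the pipeline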

@@ -547,7 +547,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float]]:
) -> List[Tuple[Any, float, List[float]]]:
"""
Generate the prompt from model.
@@ -573,7 +573,11 @@ class TextGenerationModel(HuggingFaceModel):
num_return_sequences=num_return,
)
final_results = [
(cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])))
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[float], r["logprobs"]),
)
for r in result
]
return final_results
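Each final_results entry is now a (text, sequence logprob, per-token logprobs) triple, where the second element is the sum of the third; a minimal check with illustrative values:

import math

final_results = [(" world", -4.21, [-1.37, -2.84])]
text, total_logprob, token_logprobs = final_results[0]
assert math.isclose(total_logprob, sum(token_logprobs))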
@@ -581,7 +585,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[float]:
) -> List[Tuple[float, List[float]]]:
"""
Score a sequence of choices.
@@ -618,4 +622,9 @@ class TextGenerationModel(HuggingFaceModel):
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_log_prob = seq_token_log_probs.sum(dim=-1)
return seq_log_prob.tolist()
return [
(seq, seq_token)
for seq, seq_token in zip(
seq_log_prob.tolist(), seq_token_log_probs.tolist()
)
]
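The zip above pairs each sequence-level score with its per-token scores; a minimal numeric sketch of that pairing, assuming seq_token_log_probs has shape (batch, seq_len) before the final sum:

import torch

seq_token_log_probs = torch.tensor([[-1.2, -0.5], [-2.0, -0.3]])
seq_log_prob = seq_token_log_probs.sum(dim=-1)
pairs = list(zip(seq_log_prob.tolist(), seq_token_log_probs.tolist()))
# pairs is roughly [(-1.7, [-1.2, -0.5]), (-2.3, [-2.0, -0.3])]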

@@ -48,7 +48,7 @@ class Model(ABC):
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float]]:
) -> List[Tuple[Any, float, List[float]]]:
"""
Generate the prompt from model.
@@ -59,6 +59,8 @@
Returns:
list of generated text (list of length 1 for 1 generation).
Each item is the response, answer logprob,
and list of logprobs for each token.
"""
raise NotImplementedError()
@@ -76,7 +78,7 @@ class Model(ABC):
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[float]:
) -> List[Tuple[float, List[float]]]:
"""
Score a sequence of choices.
@@ -85,5 +87,8 @@
The prompt to score the choices against.
**kwargs:
Additional keyword arguments passed along to the :obj:`__call__` method.
Returns:
Tuple of scores for each choice and logprobs for the tokens of each choice.
"""
raise NotImplementedError()
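A sketch of a subclass conforming to the updated abstract signatures; DummyModel and its return values are hypothetical, only the two changed methods are shown, and it assumes Model (the ABC above) is in scope:

from typing import Any, List, Tuple, Union

class DummyModel(Model):
    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float, List[float]]]:
        # (response, sequence logprob, per-token logprobs)
        return [("hello", -1.5, [-0.5, -1.0])]

    def score_sequence(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[float, List[float]]]:
        # (sequence logprob, per-token logprobs) for each choice
        return [(-1.5, [-0.5, -1.0])]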

@@ -36,6 +36,7 @@ class ModelResponse:
{
key: result[key],
"logprob": result["logprob"],
"token_logprobs": result["token_logprobs"],
}
if key == "text"
else {

@@ -8,12 +8,12 @@ from manifest.response import Response
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"generation_key": "choices",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"generation_key": "choices",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
"item_key": "array",
},
}

@@ -110,7 +110,7 @@ class AI21Client(Client):
"choices": [
{
"text": item["data"]["text"],
"logprobs": item["data"]["tokens"],
"token_logprobs": item["data"]["tokens"],
}
for item in response["completions"]
],

@@ -110,7 +110,7 @@ class CohereClient(Client):
{
"text": item["text"],
"text_logprob": item.get("likelihood", None),
"logprobs": item.get("token_likelihoods", None),
"token_logprobs": item.get("token_likelihoods", None),
}
for item in response["generations"]
],
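These client changes rename each provider's per-token field to the shared token_logprobs key; a sketch of the Cohere-style mapping above, with item as an illustrative raw generation:

item = {
    "text": " world",
    "likelihood": -4.2,
    "token_likelihoods": [{"token": " world", "likelihood": -4.2}],
}
choice = {
    "text": item["text"],
    "text_logprob": item.get("likelihood", None),
    "token_logprobs": item.get("token_likelihoods", None),
}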

@@ -158,7 +158,7 @@ class TOMAClient(Client):
"choices": [
{
"text": item["text"],
# "logprobs": [],
# "token_logprobs": [],
}
for item in response["output"]["choices"]
],

@@ -24,7 +24,7 @@ class Response:
cached: bool,
request_params: Dict,
generation_key: str = "choices",
logits_key: str = "logprobs",
logits_key: str = "token_logprobs",
item_key: str = "text",
):
"""
@@ -68,8 +68,8 @@ class Response:
self._response[self.generation_key][0][self.logits_key], list
):
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with logprobs field"
f"{self.logits_key} must be a list of items "
"one for each token in the choice."
)
if isinstance(
self._response[self.generation_key][0][self.item_key], np.ndarray
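A sketch of a serialized response that satisfies the check above, assuming the default logits_key of token_logprobs: the field holds a list with one value per token of the choice (values illustrative):

serialized = {
    "choices": [
        {"text": " world", "token_logprobs": [-1.37, -2.84]},
    ]
}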

@@ -141,8 +141,10 @@ def test_gpt_score() -> None:
result = model.score_sequence(inputs)
assert result is not None
assert len(result) == 2
assert math.isclose(round(result[0], 3), -19.935)
assert math.isclose(round(result[1], 3), -45.831)
assert math.isclose(round(result[0][0], 3), -19.935)
assert math.isclose(round(result[1][0], 3), -45.831)
assert isinstance(result[0][1], list)
assert isinstance(result[1][1], list)
def test_batch_gpt_generate() -> None:

@@ -279,7 +279,7 @@ def test_score_run(sqlite_cache: str) -> None:
)
assert result == {
"generation_key": "choices",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
"item_key": "text",
"item_dtype": None,
"response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
@@ -303,7 +303,7 @@ def test_score_run(sqlite_cache: str) -> None:
)
assert result == {
"generation_key": "choices",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
"item_key": "text",
"item_dtype": None,
"response": {
@@ -338,7 +338,7 @@ def test_log_query(session_cache: str) -> None:
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
}
assert manifest.get_last_queries(1) == [("This is a prompt", "hello")]
assert manifest.get_last_queries(1, return_raw_values=True) == [
@@ -361,7 +361,7 @@ def test_log_query(session_cache: str) -> None:
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "logprobs",
"logits_key": "token_logprobs",
"request_params": query_key,
"response": {"choices": [{"text": "hello"}, {"text": "hello"}]},
}
