feat: remove choice logits and use prompt scoring (#50)

Laurel Orr 1 year ago committed by GitHub
parent 876d27bd2d
commit 94b57a6e6f

@@ -206,29 +206,26 @@ def embed() -> Dict:
return {"result": results}
@app.route("/choice_logits", methods=["POST"])
def choice_logits() -> Response:
"""Get maximal likely choice via max logits after generation."""
@app.route("/score_sequence", methods=["POST"])
def score_sequence() -> Response:
"""Get logprob of prompt."""
prompt = request.json["prompt"]
del request.json["prompt"]
gold_choices = request.json["gold_choices"]
del request.json["gold_choices"]
generation_args = request.json
if not isinstance(prompt, (str, list)):
raise ValueError("Prompt must be a str or list of str")
if not isinstance(gold_choices, list):
raise ValueError("Gold choices must be a list of string choices")
try:
choice_score_list = model.logits_scoring(
prompt, gold_choices, **generation_args
)
results = [{"text": r[0], "text_logprob": r[1]} for r in choice_score_list]
score_list = model.score_sequence(prompt, **generation_args)
results = [
{"text": prompt if isinstance(prompt, str) else prompt[i], "logprob": r}
for i, r in enumerate(score_list)
]
# transform the result into the openai format
return Response(
json.dumps(
ModelResponse(results, response_type="choice_selection").__dict__()
ModelResponse(results, response_type="prompt_logit_score").__dict__()
),
status=200,
)
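For reference, a minimal client-side sketch of calling the new endpoint (the host/port and the extra generation arguments are assumptions; only `prompt` is required by the handler above):

```python
import requests

# Hypothetical address of the Flask app above; adjust host/port to your deployment.
HOST = "http://localhost:5000"

payload = {
    "prompt": ["This is a prompt", "Hello is another prompt"],
    # Any additional keys are forwarded to model.score_sequence(**generation_args).
}
res = requests.post(HOST + "/score_sequence", json=payload)
res.raise_for_status()
# The handler wraps results in ModelResponse(..., response_type="prompt_logit_score");
# each scored entry carries the prompt text and its logprob.
print(res.json())
```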

@@ -3,6 +3,7 @@ import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import deepspeed
import numpy as np
import PIL
import torch
@@ -24,7 +25,6 @@ from transformers import (
PreTrainedTokenizer,
)
import deepspeed
from manifest.api.models.model import Model
MODEL_REGISTRY = {

@@ -72,17 +72,16 @@ class Model(ABC):
"""
raise NotImplementedError()
def logits_scoring(
self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
) -> List[Tuple[Any, float]]:
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[float]:
"""
Given the prompt and gold choices, choose the best choice with max logits.
Score a sequence of choices.
Args:
prompt: promt to generate from.
gold_choices: list of choices to choose from.
Returns:
the returned gold choice and the score.
prompt (:obj:`str` or :obj:`List[str]`):
The prompt to score the choices against.
**kwargs:
Additional keyword arguments passed along to the :obj:`__call__` method.
"""
raise NotImplementedError()
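As an illustration of what a concrete `score_sequence` implementation might look like (not part of this commit), here is a minimal sketch using a Hugging Face causal LM: one forward pass over the prompt, then the sum of per-token log-probabilities. The model choice ("gpt2") and the sum reduction are assumptions.

```python
from typing import List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works; "gpt2" is just a small, convenient default.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()


def score_sequence(prompt: Union[str, List[str]]) -> List[float]:
    """Return the total log-probability of each prompt under the model."""
    prompts = [prompt] if isinstance(prompt, str) else prompt
    scores = []
    for text in prompts:
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        with torch.no_grad():
            logits = model(input_ids).logits
        # The log-prob of token t is read from the distribution predicted at t-1.
        log_probs = torch.log_softmax(logits[0, :-1], dim=-1)
        token_log_probs = log_probs.gather(
            1, input_ids[0, 1:].unsqueeze(-1)
        ).squeeze(-1)
        scores.append(token_log_probs.sum().item())
    return scores
```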

@@ -14,12 +14,12 @@ class ModelResponse:
self.response_type = response_type
if self.response_type not in {
"text_completion",
"choice_selection",
"prompt_logit_score",
"image_generation",
}:
raise ValueError(
f"Invalid response type: {self.response_type}. "
"Must be one of: text_completion, choice_selection."
"Must be one of: text_completion, prompt_logit_score, image_generation."
)
self.response_id = str(uuid.uuid4())
self.created = int(time.time())
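For orientation, a hedged sketch of how the new response type is constructed, mirroring the `/score_sequence` endpoint above (the import path and the logprob value are assumptions; the serialized fields are whatever `ModelResponse.__dict__()` emits):

```python
from manifest.api.response import ModelResponse  # import path assumed from the repo layout

results = [{"text": "This is a prompt", "logprob": -12.7}]  # illustrative score
response = ModelResponse(results, response_type="prompt_logit_score")
payload = response.__dict__()  # dict that the Flask endpoint json.dumps()'s

# An unrecognized response type trips the check above:
# ModelResponse(results, response_type="bogus")  -> ValueError
```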

@@ -167,16 +167,14 @@ class Client(ABC):
return _run_completion, request_params
def get_choice_logit_request(
def get_score_prompt_request(
self,
gold_choices: List[str],
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function for choosing max choices.
Get the logit score of the prompt via a forward pass of the model.
Args:
gold_choices: choices for model to choose from via max logits.
request: request.
Returns:
@@ -184,5 +182,5 @@ class Client(ABC):
request parameters as dict.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support choice logit request."
f"{self.__class__.__name__} does not support prompt scoring request."
)

@@ -1,6 +1,6 @@
"""Dummy client."""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
@@ -93,16 +93,14 @@ class DummyClient(Client):
return _run_completion, request_params
def get_choice_logit_request(
def get_score_prompt_request(
self,
gold_choices: List[str],
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function for choosing max choices.
Get the logit score of the prompt via a forward pass of the model.
Args:
gold_choices: choices for model to choose from via max logits.
request: request.
Returns:
@@ -113,9 +111,19 @@ class DummyClient(Client):
num_results = len(request.prompt)
else:
num_results = 1
request_params = {"prompt": request.prompt, "gold_choices": gold_choices}
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
return {"choices": [{"text": gold_choices[0]}] * num_results}
return {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
}
for i in range(num_results)
]
}
return _run_completion, request_params

@@ -1,6 +1,6 @@
"""Hugging Face client."""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
import requests
@@ -78,16 +78,14 @@ class HuggingFaceClient(Client):
res = requests.post(self.host + "/params")
return res.json()
def get_choice_logit_request(
def get_score_prompt_request(
self,
gold_choices: List[str],
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function for choosing max choices.
Get the logit score of the prompt via a forward pass of the model.
Args:
gold_choices: choices for model to choose from via max logits.
request: request.
Returns:
@@ -97,10 +95,10 @@ class HuggingFaceClient(Client):
request_params = request.to_dict(self.PARAMS)
retry_timeout = request_params.pop("client_timeout")
# Do not add params like we do with request as the model isn't sampling
request_params = {"prompt": request.prompt, "gold_choices": gold_choices}
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
post_str = self.host + "/choice_logits"
post_str = self.host + "/score_sequence"
try:
res = requests.post(
post_str,

@@ -1,6 +1,6 @@
"""Manifest class."""
import logging
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import numpy as np
@@ -16,6 +16,7 @@ from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.toma import TOMAClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session
@@ -145,10 +146,33 @@ class Manifest:
if stop_token is not None:
self.stop_token = stop_token
def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
"""Validate kwargs.
Args:
kwargs: kwargs to validate.
request_params: request object to validate against.
"""
# Check for invalid kwargs
non_request_kwargs = [
(k, v) for k, v in kwargs.items() if k not in request_params.__dict__
]
if len(non_request_kwargs) > 0:
raise ValueError(
f"{list(non_request_kwargs)} arguments are not recognized."
)
# Warn for valid but unused kwargs
request_unused_kwargs = [
(k, v) for k, v in kwargs.items() if k not in non_request_kwargs
]
if len(request_unused_kwargs) > 0:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
return
def run(
self,
prompt: Union[str, List[str]],
gold_choices: Optional[List[str]] = None,
overwrite_cache: bool = False,
run_id: Optional[str] = None,
stop_token: Optional[str] = None,
@@ -160,7 +184,6 @@ class Manifest:
Args:
prompt: prompt(s) to run.
gold_choices: gold choices for max logit response (only HF models).
overwrite_cache: whether to overwrite cache.
run_id: run id for cache to repeat same run.
stop_token: stop token for prompt generation.
@@ -179,31 +202,9 @@ class Manifest:
# Avoid nested list of results - enforce n = 1 for batch
if is_batch and request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
if gold_choices is None:
possible_request, full_kwargs = self.client.get_request(request_params)
else:
try:
possible_request, full_kwargs = cast(
HuggingFaceClient, self.client
).get_choice_logit_request(gold_choices, request_params)
except AttributeError:
raise ValueError("`gold_choices` only supported for HF models.")
possible_request, full_kwargs = self.client.get_request(request_params)
# Check for invalid kwargs
non_request_kwargs = [
(k, v) for k, v in kwargs.items() if k not in request_params.__dict__
]
if len(non_request_kwargs) > 0:
raise ValueError(
f"{list(non_request_kwargs)} arguments are not recognized."
)
# Warn for valid but unused kwargs
request_unused_kwargs = [
(k, v) for k, v in kwargs.items() if k not in non_request_kwargs
]
if len(request_unused_kwargs) > 0:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
self._validate_kwargs(kwargs, request_params)
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent
@@ -220,6 +221,47 @@ class Manifest:
else:
return response_obj.get_response(stop_token, is_batch)
def score_prompt(
self,
prompt: Union[str, List[str]],
overwrite_cache: bool = False,
**kwargs: Any,
) -> Dict:
"""
Score the prompt via forward pass of the model - no sampling or generation.
Returns the response object with logits of the prompt.
Prompt scoring is not part of a session cache.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
Returns:
response from prompt.
"""
# Must pass kwargs as a dict so the client's "pop" methods can remove used arguments
request_params = self.client.get_request_params(prompt, kwargs)
if request_params.n > 1:
raise ValueError("Sequence scoring does not support n > 1.")
try:
possible_request, full_kwargs = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
self._validate_kwargs(kwargs, request_params)
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent
cache_key.update(self.client.get_model_params())
response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
return response_obj.to_dict()
def get_last_queries(
self,
last_n: int = -1,

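A short usage sketch of the new `Manifest.score_prompt` API, mirroring the rewritten test below. The dummy client returns a fixed logprob of 0.3; an HF-backed client would return real prompt log-probabilities. The `cache_name="noop"` choice is an assumption (the tests use a sqlite cache).

```python
from manifest import Manifest

# Dummy client as in the tests; a real deployment would use the Hugging Face
# client pointed at a running manifest.api server.
manifest = Manifest(client_name="dummy", cache_name="noop")

single = manifest.score_prompt("This is a prompt")
print(single["response"]["choices"][0]["logprob"])  # 0.3 for the dummy client

# Batched scoring returns one choice per prompt.
batch = manifest.score_prompt(["Hello is a prompt", "Hello is another prompt"])
print([c["logprob"] for c in batch["response"]["choices"]])
```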
@@ -254,11 +254,7 @@ def test_batch_run(
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("return_response", [True, False])
def test_choices_run(
sqlite_cache: str, session_cache: str, return_response: bool
) -> None:
def test_score_run(sqlite_cache: str) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@@ -267,71 +263,13 @@ def test_choices_run(
)
prompt = "This is a prompt"
# Dummy client will always return first choice
choices = ["cat", "dog"]
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "This is a prompt",
"gold_choices": ["cat", "dog"],
"engine": "dummy",
},
sort_keys=True,
)
)
is not None
)
assert res == "cat"
result = manifest.score_prompt(prompt)
prompt = "Hello is a prompt"
choices = ["cat", "dog"]
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"gold_choices": ["cat", "dog"],
"engine": "dummy",
},
sort_keys=True,
)
)
is not None
)
assert res == "cat"
prompt = "Hello is a prompt"
choices = ["callt", "dog"]
result = manifest.run(
prompt,
gold_choices=choices,
stop_token="ll",
return_response=return_response,
)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(stop_token="ll")
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"gold_choices": ["cat", "dog"],
"prompt": "This is a prompt",
"engine": "dummy",
},
sort_keys=True,
@@ -339,27 +277,23 @@ def test_choices_run(
)
is not None
)
assert res == "ca"
assert result == {
"generation_key": "choices",
"logits_key": "logprobs",
"item_key": "text",
"item_dtype": None,
"response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
"cached": False,
"request_params": {"prompt": "This is a prompt", "engine": "dummy"},
}
prompt_lst = ["Hello is a prompt", "Hello is a prompt"]
choices = ["callt", "dog"]
result = manifest.run(
prompt_lst,
gold_choices=choices,
stop_token="ll",
return_response=return_response,
)
if return_response:
assert isinstance(result, Response)
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
else:
res = cast(str, result)
prompt_list = ["Hello is a prompt", "Hello is another prompt"]
result = manifest.score_prompt(prompt_list)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": ["Hello is a prompt", "Hello is a prompt"],
"gold_choices": ["callt", "dog"],
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "dummy",
},
sort_keys=True,
@@ -367,7 +301,23 @@ def test_choices_run(
)
is not None
)
assert res == ["ca", "ca"]
assert result == {
"generation_key": "choices",
"logits_key": "logprobs",
"item_key": "text",
"item_dtype": None,
"response": {
"choices": [
{"text": "Hello is a prompt", "logprob": 0.3},
{"text": "Hello is another prompt", "logprob": 0.3},
]
},
"cached": False,
"request_params": {
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "dummy",
},
}
@pytest.mark.usefixtures("session_cache")
