feat: remove choice logits and use prompt scoring (#50)

pull/82/head
Laurel Orr committed 1 year ago via GitHub
parent 876d27bd2d
commit 94b57a6e6f

@@ -206,29 +206,26 @@ def embed() -> Dict:
     return {"result": results}
 
 
-@app.route("/choice_logits", methods=["POST"])
-def choice_logits() -> Response:
-    """Get maximal likely choice via max logits after generation."""
+@app.route("/score_sequence", methods=["POST"])
+def score_sequence() -> Response:
+    """Get logprob of prompt."""
     prompt = request.json["prompt"]
     del request.json["prompt"]
-    gold_choices = request.json["gold_choices"]
-    del request.json["gold_choices"]
     generation_args = request.json
     if not isinstance(prompt, (str, list)):
         raise ValueError("Prompt must be a str or list of str")
-    if not isinstance(gold_choices, list):
-        raise ValueError("Gold choices must be a list of string choices")
     try:
-        choice_score_list = model.logits_scoring(
-            prompt, gold_choices, **generation_args
-        )
-        results = [{"text": r[0], "text_logprob": r[1]} for r in choice_score_list]
+        score_list = model.score_sequence(prompt, **generation_args)
+        results = [
+            {"text": prompt if isinstance(prompt, str) else prompt[i], "logprob": r}
+            for i, r in enumerate(score_list)
+        ]
         # transform the result into the openai format
         return Response(
             json.dumps(
-                ModelResponse(results, response_type="choice_selection").__dict__()
+                ModelResponse(results, response_type="prompt_logit_score").__dict__()
            ),
            status=200,
        )
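For context, a minimal sketch of exercising the renamed route; it assumes a manifest API server is already running, and the host and port below are illustrative, not taken from this change.

import requests

# Illustrative address; point this at wherever the manifest API server is running.
HOST = "http://localhost:5000"

resp = requests.post(
    HOST + "/score_sequence",
    json={"prompt": ["Paris is in France", "Paris is in Italy"]},
)
resp.raise_for_status()
# The body wraps the `results` list built above: each entry echoes the prompt
# text alongside its sequence "logprob".
print(resp.json())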

@@ -3,6 +3,7 @@ import json
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
+import deepspeed
 import numpy as np
 import PIL
 import torch
@@ -24,7 +25,6 @@ from transformers import (
     PreTrainedTokenizer,
 )
 
-import deepspeed
 from manifest.api.models.model import Model
 
 MODEL_REGISTRY = {

@@ -72,17 +72,16 @@ class Model(ABC):
         """
         raise NotImplementedError()
 
-    def logits_scoring(
-        self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
-    ) -> List[Tuple[Any, float]]:
+    def score_sequence(
+        self, prompt: Union[str, List[str]], **kwargs: Any
+    ) -> List[float]:
         """
-        Given the prompt and gold choices, choose the best choice with max logits.
+        Score a sequence of choices.
 
         Args:
-            prompt: promt to generate from.
-            gold_choices: list of choices to choose from.
-
-        Returns:
-            the returned gold choice and the score.
+            prompt (:obj:`str` or :obj:`List[str]`):
+                The prompt to score the choices against.
+            **kwargs:
+                Additional keyword arguments passed along to the :obj:`__call__` method.
         """
         raise NotImplementedError()
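As a rough, self-contained sketch of one way a subclass could satisfy this interface for a causal LM (plain transformers, illustrative only, not the implementation in this PR): run a forward pass over the prompt and sum the log-probabilities of its tokens.

from typing import Any, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def score_sequence_sketch(
    prompt: Union[str, List[str]], model_name: str = "gpt2", **kwargs: Any
) -> List[float]:
    """Return the summed log-probability of each prompt under the model."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    prompts = [prompt] if isinstance(prompt, str) else prompt
    scores = []
    with torch.no_grad():
        for text in prompts:
            input_ids = tokenizer(text, return_tensors="pt").input_ids
            logits = model(input_ids).logits
            # Predict token t from positions < t, then pick out the log-probs
            # of the tokens that actually appear in the prompt.
            log_probs = torch.log_softmax(logits[0, :-1], dim=-1)
            token_log_probs = log_probs.gather(
                1, input_ids[0, 1:].unsqueeze(-1)
            ).squeeze(-1)
            scores.append(token_log_probs.sum().item())
    return scores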

@@ -14,12 +14,12 @@ class ModelResponse:
         self.response_type = response_type
         if self.response_type not in {
             "text_completion",
-            "choice_selection",
+            "prompt_logit_score",
             "image_generation",
         }:
             raise ValueError(
                 f"Invalid response type: {self.response_type}. "
-                "Must be one of: text_completion, choice_selection."
+                "Must be one of: text_completion, prompt_logit_score, image_generation."
             )
         self.response_id = str(uuid.uuid4())
         self.created = int(time.time())

@@ -167,16 +167,14 @@ class Client(ABC):
         return _run_completion, request_params
 
-    def get_choice_logit_request(
+    def get_score_prompt_request(
         self,
-        gold_choices: List[str],
         request: Request,
     ) -> Tuple[Callable[[], Dict], Dict]:
         """
-        Get request string function for choosing max choices.
+        Get the logit score of the prompt via a forward pass of the model.
 
         Args:
-            gold_choices: choices for model to choose from via max logits.
             request: request.
 
         Returns:
@@ -184,5 +182,5 @@ class Client(ABC):
             request parameters as dict.
         """
         raise NotImplementedError(
-            f"{self.__class__.__name__} does not support choice logit request."
+            f"{self.__class__.__name__} does not support prompt scoring request."
         )

@@ -1,6 +1,6 @@
 """Dummy client."""
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 from manifest.clients.client import Client
 from manifest.request import LMRequest, Request
@@ -93,16 +93,14 @@ class DummyClient(Client):
         return _run_completion, request_params
 
-    def get_choice_logit_request(
+    def get_score_prompt_request(
         self,
-        gold_choices: List[str],
         request: Request,
     ) -> Tuple[Callable[[], Dict], Dict]:
         """
-        Get request string function for choosing max choices.
+        Get the logit score of the prompt via a forward pass of the model.
 
         Args:
-            gold_choices: choices for model to choose from via max logits.
             request: request.
 
         Returns:
@@ -113,9 +111,19 @@ class DummyClient(Client):
             num_results = len(request.prompt)
         else:
             num_results = 1
-        request_params = {"prompt": request.prompt, "gold_choices": gold_choices}
+        request_params = {"prompt": request.prompt}
 
         def _run_completion() -> Dict:
-            return {"choices": [{"text": gold_choices[0]}] * num_results}
+            return {
+                "choices": [
+                    {
+                        "text": request.prompt
+                        if isinstance(request.prompt, str)
+                        else request.prompt[i],
+                        "logprob": 0.3,
+                    }
+                    for i in range(num_results)
+                ]
+            }
 
         return _run_completion, request_params
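For clarity, the payload the dummy _run_completion now fabricates for a two-prompt batch (plain illustration, no client plumbing involved): one choice per prompt, echoing the prompt text with the fixed dummy logprob of 0.3.

prompts = ["Hello is a prompt", "Hello is another prompt"]
# Mirrors the comprehension above for num_results == 2.
payload = {"choices": [{"text": p, "logprob": 0.3} for p in prompts]}
assert payload["choices"][0] == {"text": "Hello is a prompt", "logprob": 0.3}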

@@ -1,6 +1,6 @@
 """Hugging Face client."""
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import requests
@@ -78,16 +78,14 @@ class HuggingFaceClient(Client):
         res = requests.post(self.host + "/params")
         return res.json()
 
-    def get_choice_logit_request(
+    def get_score_prompt_request(
         self,
-        gold_choices: List[str],
         request: Request,
     ) -> Tuple[Callable[[], Dict], Dict]:
         """
-        Get request string function for choosing max choices.
+        Get the logit score of the prompt via a forward pass of the model.
 
         Args:
-            gold_choices: choices for model to choose from via max logits.
             request: request.
 
         Returns:
@@ -97,10 +95,10 @@ class HuggingFaceClient(Client):
         request_params = request.to_dict(self.PARAMS)
         retry_timeout = request_params.pop("client_timeout")
         # Do not add params like we do with request as the model isn't sampling
-        request_params = {"prompt": request.prompt, "gold_choices": gold_choices}
+        request_params = {"prompt": request.prompt}
 
         def _run_completion() -> Dict:
-            post_str = self.host + "/choice_logits"
+            post_str = self.host + "/score_sequence"
             try:
                 res = requests.post(
                     post_str,

@@ -1,6 +1,6 @@
 """Manifest class."""
 import logging
-from typing import Any, List, Optional, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
@@ -16,6 +16,7 @@ from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
 from manifest.clients.toma import TOMAClient
 from manifest.clients.toma_diffuser import TOMADiffuserClient
+from manifest.request import Request
 from manifest.response import Response
 from manifest.session import Session
@@ -145,10 +146,33 @@ class Manifest:
         if stop_token is not None:
             self.stop_token = stop_token
 
+    def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
+        """Validate kwargs.
+
+        Args:
+            kwargs: kwargs to validate.
+            request_params: request object to validate against.
+        """
+        # Check for invalid kwargs
+        non_request_kwargs = [
+            (k, v) for k, v in kwargs.items() if k not in request_params.__dict__
+        ]
+        if len(non_request_kwargs) > 0:
+            raise ValueError(
+                f"{list(non_request_kwargs)} arguments are not recognized."
+            )
+        # Warn for valid but unused kwargs
+        request_unused_kwargs = [
+            (k, v) for k, v in kwargs.items() if k not in non_request_kwargs
+        ]
+        if len(request_unused_kwargs) > 0:
+            logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
+        return
+
     def run(
         self,
         prompt: Union[str, List[str]],
-        gold_choices: Optional[List[str]] = None,
         overwrite_cache: bool = False,
         run_id: Optional[str] = None,
         stop_token: Optional[str] = None,
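A small, self-contained illustration of the rule _validate_kwargs enforces, using a stand-in dataclass rather than the real Request object: any keyword argument not present on the request is rejected with a ValueError.

from dataclasses import dataclass


@dataclass
class FakeRequest:
    # Stand-in for the request object; fields are illustrative.
    prompt: str = ""
    temperature: float = 1.0


def check(kwargs: dict, request_params: FakeRequest) -> None:
    # Same membership test as _validate_kwargs above.
    unknown = [(k, v) for k, v in kwargs.items() if k not in request_params.__dict__]
    if unknown:
        raise ValueError(f"{unknown} arguments are not recognized.")


check({"temperature": 0.7}, FakeRequest())  # accepted: field exists on the request
try:
    check({"temperatur": 0.7}, FakeRequest())  # typo: rejected
except ValueError as err:
    print(err)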
@@ -160,7 +184,6 @@ class Manifest:
         Args:
             prompt: prompt(s) to run.
-            gold_choices: gold choices for max logit response (only HF models).
             overwrite_cache: whether to overwrite cache.
             run_id: run id for cache to repeat same run.
             stop_token: stop token for prompt generation.
@@ -179,31 +202,9 @@ class Manifest:
         # Avoid nested list of results - enforce n = 1 for batch
         if is_batch and request_params.n > 1:
             raise ValueError("Batch mode does not support n > 1.")
-        if gold_choices is None:
-            possible_request, full_kwargs = self.client.get_request(request_params)
-        else:
-            try:
-                possible_request, full_kwargs = cast(
-                    HuggingFaceClient, self.client
-                ).get_choice_logit_request(gold_choices, request_params)
-            except AttributeError:
-                raise ValueError("`gold_choices` only supported for HF models.")
-        # Check for invalid kwargs
-        non_request_kwargs = [
-            (k, v) for k, v in kwargs.items() if k not in request_params.__dict__
-        ]
-        if len(non_request_kwargs) > 0:
-            raise ValueError(
-                f"{list(non_request_kwargs)} arguments are not recognized."
-            )
-        # Warn for valid but unused kwargs
-        request_unused_kwargs = [
-            (k, v) for k, v in kwargs.items() if k not in non_request_kwargs
-        ]
-        if len(request_unused_kwargs) > 0:
-            logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
+        possible_request, full_kwargs = self.client.get_request(request_params)
+        self._validate_kwargs(kwargs, request_params)
         # Create cacke key
         cache_key = full_kwargs.copy()
         # Make query model dependent
@@ -220,6 +221,47 @@ class Manifest:
         else:
             return response_obj.get_response(stop_token, is_batch)
 
+    def score_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        overwrite_cache: bool = False,
+        **kwargs: Any,
+    ) -> Dict:
+        """
+        Score the prompt via forward pass of the model - no sampling or generation.
+
+        Returns the response object with logits of the prompt.
+
+        Prompt scoring is not part of a session cache.
+
+        Args:
+            prompt: prompt(s) to run.
+            overwrite_cache: whether to overwrite cache.
+
+        Returns:
+            response from prompt.
+        """
+        # Must pass kwargs as dict for client "pop" methods removed used arguments
+        request_params = self.client.get_request_params(prompt, kwargs)
+        if request_params.n > 1:
+            raise ValueError("Sequence scoring does not support n > 1.")
+        try:
+            possible_request, full_kwargs = cast(
+                HuggingFaceClient, self.client
+            ).get_score_prompt_request(request_params)
+        except AttributeError:
+            raise ValueError("`score_prompt` only supported for HF models.")
+        self._validate_kwargs(kwargs, request_params)
+        # Create cacke key
+        cache_key = full_kwargs.copy()
+        # Make query model dependent
+        cache_key.update(self.client.get_model_params())
+        response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
+        return response_obj.to_dict()
+
     def get_last_queries(
         self,
         last_n: int = -1,
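A hedged usage sketch for the new score_prompt entry point; it assumes a manifest server backed by an HF model is reachable at the (illustrative) address below, since only HF-style clients implement get_score_prompt_request.

from manifest import Manifest

manifest = Manifest(
    client_name="huggingface",
    client_connection="http://127.0.0.1:5000",  # illustrative address
)
scored = manifest.score_prompt(["Paris is in France", "Paris is in Italy"])
# score_prompt returns Response.to_dict(); the scored prompts sit under
# "response" -> "choices", each with "text" and "logprob" (see the tests below).
for choice in scored["response"]["choices"]:
    print(choice["text"], choice["logprob"])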

@@ -254,11 +254,7 @@ def test_batch_run(
 
 @pytest.mark.usefixtures("sqlite_cache")
-@pytest.mark.usefixtures("session_cache")
-@pytest.mark.parametrize("return_response", [True, False])
-def test_choices_run(
-    sqlite_cache: str, session_cache: str, return_response: bool
-) -> None:
+def test_score_run(sqlite_cache: str) -> None:
     """Test manifest run."""
     manifest = Manifest(
         client_name="dummy",
@@ -267,71 +263,13 @@ def test_choices_run(
     )
     prompt = "This is a prompt"
-    # Dummy client will always return first choice
-    choices = ["cat", "dog"]
-    result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
-    if return_response:
-        assert isinstance(result, Response)
-        res = cast(Response, result).get_response(manifest.stop_token)
-    else:
-        res = cast(str, result)
-    assert (
-        manifest.cache.get_key(
-            json.dumps(
-                {
-                    "prompt": "This is a prompt",
-                    "gold_choices": ["cat", "dog"],
-                    "engine": "dummy",
-                },
-                sort_keys=True,
-            )
-        )
-        is not None
-    )
-    assert res == "cat"
-
-    prompt = "Hello is a prompt"
-    choices = ["cat", "dog"]
-    result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
-    if return_response:
-        assert isinstance(result, Response)
-        res = cast(Response, result).get_response(manifest.stop_token)
-    else:
-        res = cast(str, result)
+    result = manifest.score_prompt(prompt)
     assert (
         manifest.cache.get_key(
             json.dumps(
                 {
-                    "prompt": "Hello is a prompt",
-                    "gold_choices": ["cat", "dog"],
-                    "engine": "dummy",
-                },
-                sort_keys=True,
-            )
-        )
-        is not None
-    )
-    assert res == "cat"
-
-    prompt = "Hello is a prompt"
-    choices = ["callt", "dog"]
-    result = manifest.run(
-        prompt,
-        gold_choices=choices,
-        stop_token="ll",
-        return_response=return_response,
-    )
-    if return_response:
-        assert isinstance(result, Response)
-        res = cast(Response, result).get_response(stop_token="ll")
-    else:
-        res = cast(str, result)
-    assert (
-        manifest.cache.get_key(
-            json.dumps(
-                {
-                    "prompt": "Hello is a prompt",
-                    "gold_choices": ["cat", "dog"],
+                    "prompt": "This is a prompt",
                     "engine": "dummy",
                 },
                 sort_keys=True,
@@ -339,27 +277,23 @@ def test_choices_run(
         )
         is not None
     )
-    assert res == "ca"
+    assert result == {
+        "generation_key": "choices",
+        "logits_key": "logprobs",
+        "item_key": "text",
+        "item_dtype": None,
+        "response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
+        "cached": False,
+        "request_params": {"prompt": "This is a prompt", "engine": "dummy"},
+    }
 
-    prompt_lst = ["Hello is a prompt", "Hello is a prompt"]
-    choices = ["callt", "dog"]
-    result = manifest.run(
-        prompt_lst,
-        gold_choices=choices,
-        stop_token="ll",
-        return_response=return_response,
-    )
-    if return_response:
-        assert isinstance(result, Response)
-        res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
-    else:
-        res = cast(str, result)
+    prompt_list = ["Hello is a prompt", "Hello is another prompt"]
+    result = manifest.score_prompt(prompt_list)
     assert (
         manifest.cache.get_key(
             json.dumps(
                 {
-                    "prompt": ["Hello is a prompt", "Hello is a prompt"],
-                    "gold_choices": ["callt", "dog"],
+                    "prompt": ["Hello is a prompt", "Hello is another prompt"],
                     "engine": "dummy",
                 },
                 sort_keys=True,
@@ -367,7 +301,23 @@
         )
         is not None
     )
-    assert res == ["ca", "ca"]
+    assert result == {
+        "generation_key": "choices",
+        "logits_key": "logprobs",
+        "item_key": "text",
+        "item_dtype": None,
+        "response": {
+            "choices": [
+                {"text": "Hello is a prompt", "logprob": 0.3},
+                {"text": "Hello is another prompt", "logprob": 0.3},
+            ]
+        },
+        "cached": False,
+        "request_params": {
+            "prompt": ["Hello is a prompt", "Hello is another prompt"],
+            "engine": "dummy",
+        },
+    }
 
 
 @pytest.mark.usefixtures("session_cache")
