@@ -1,6 +1,6 @@
"""Manifest class."""
import logging
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np

@@ -16,6 +16,7 @@ from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.toma import TOMAClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session

@@ -145,10 +146,33 @@ class Manifest:
        if stop_token is not None:
            self.stop_token = stop_token

    def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
        """Validate kwargs.

        Args:
            kwargs: kwargs to validate.
            request_params: request object to validate against.
        """
        # Check for invalid kwargs
        non_request_kwargs = [
            (k, v) for k, v in kwargs.items() if k not in request_params.__dict__
        ]
        if len(non_request_kwargs) > 0:
            raise ValueError(
                f"{list(non_request_kwargs)} arguments are not recognized."
            )

        # Warn for valid but unused kwargs
        request_unused_kwargs = [
            (k, v) for k, v in kwargs.items() if k not in non_request_kwargs
        ]
        if len(request_unused_kwargs) > 0:
            logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
        return

    def run(
        self,
        prompt: Union[str, List[str]],
        gold_choices: Optional[List[str]] = None,
        overwrite_cache: bool = False,
        run_id: Optional[str] = None,
        stop_token: Optional[str] = None,
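For context on the new helper: get_request_params pops every kwarg the client consumes, so anything still in kwargs when _validate_kwargs runs is either not a Request field at all (raises) or a valid field the active client never used (warns only). A minimal sketch of that contract, assuming an OpenAI-backed manifest whose client consumes a temperature parameter; the misspelled kwarg is deliberate:

    manifest = Manifest(client_name="openai")
    manifest.run("hello", temperatur=0.7)   # unknown field -> ValueError
    manifest.run("hello", temperature=0.7)  # popped by the client before validation runs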
@@ -160,7 +184,6 @@ class Manifest:

        Args:
            prompt: prompt(s) to run.
            gold_choices: gold choices for max logit response (only HF models).
            overwrite_cache: whether to overwrite cache.
            run_id: run id for cache to repeat same run.
            stop_token: stop token for prompt generation.
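As the Args above note, gold_choices switches run from free generation to picking the choice with the highest logit, which only HuggingFace-backed clients support. A rough sketch, with the connection URL as a placeholder:

    manifest = Manifest(
        client_name="huggingface",
        client_connection="http://127.0.0.1:5000",  # placeholder local HF server
    )
    label = manifest.run(
        "Review: great movie. Sentiment:",
        gold_choices=[" positive", " negative"],  # scored by max logit, not sampled
    )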
@@ -179,31 +202,9 @@ class Manifest:
        # Avoid nested list of results - enforce n = 1 for batch
        if is_batch and request_params.n > 1:
            raise ValueError("Batch mode does not support n > 1.")
        if gold_choices is None:
            possible_request, full_kwargs = self.client.get_request(request_params)
        else:
            try:
                possible_request, full_kwargs = cast(
                    HuggingFaceClient, self.client
                ).get_choice_logit_request(gold_choices, request_params)
            except AttributeError:
                raise ValueError("`gold_choices` only supported for HF models.")
        possible_request, full_kwargs = self.client.get_request(request_params)

        # Check for invalid kwargs
        non_request_kwargs = [
            (k, v) for k, v in kwargs.items() if k not in request_params.__dict__
        ]
        if len(non_request_kwargs) > 0:
            raise ValueError(
                f"{list(non_request_kwargs)} arguments are not recognized."
            )

        # Warn for valid but unused kwargs
        request_unused_kwargs = [
            (k, v) for k, v in kwargs.items() if k not in non_request_kwargs
        ]
        if len(request_unused_kwargs) > 0:
            logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
        self._validate_kwargs(kwargs, request_params)
        # Create cache key
        cache_key = full_kwargs.copy()
        # Make query model dependent
@@ -220,6 +221,47 @@ class Manifest:
        else:
            return response_obj.get_response(stop_token, is_batch)

    def score_prompt(
        self,
        prompt: Union[str, List[str]],
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Dict:
        """
        Score the prompt via forward pass of the model - no sampling or generation.

        Returns the response object with logits of the prompt.

        Prompt scoring is not part of a session cache.

        Args:
            prompt: prompt(s) to run.
            overwrite_cache: whether to overwrite cache.

        Returns:
            response from prompt.
        """
        # Must pass kwargs as dict so the client "pop" methods can remove used arguments
        request_params = self.client.get_request_params(prompt, kwargs)

        if request_params.n > 1:
            raise ValueError("Sequence scoring does not support n > 1.")

        try:
            possible_request, full_kwargs = cast(
                HuggingFaceClient, self.client
            ).get_score_prompt_request(request_params)
        except AttributeError:
            raise ValueError("`score_prompt` only supported for HF models.")

        self._validate_kwargs(kwargs, request_params)
        # Create cache key
        cache_key = full_kwargs.copy()
        # Make query model dependent
        cache_key.update(self.client.get_model_params())
        response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
        return response_obj.to_dict()
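A quick sketch of the new scoring path; again the connection URL is a placeholder, and get_score_prompt_request exists only on HuggingFace clients:

    manifest = Manifest(
        client_name="huggingface",
        client_connection="http://127.0.0.1:5000",  # placeholder local HF server
    )
    scored = manifest.score_prompt("The quick brown fox")  # forward pass only, no sampling
    # scored is a plain dict (Response.to_dict()) carrying the prompt logits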

    def get_last_queries(
        self,
        last_n: int = -1,