You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
manifest/manifest/request.py

133 lines
2.9 KiB
Python

"""Request object."""
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
# Separator used when unioning requests after the async connection pool
# (joins multiple engine names into a single string).
ENGINE_SEP = "::"
# Request fields that must NOT participate in the cache key: they affect
# transport/batching only, never the model output.
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
# Maps field name -> (renamed key, default value).
# The below should match those in Request.
DEFAULT_REQUEST_KEYS = {
    "client_timeout": ("client_timeout", 60),  # seconds
    "batch_size": ("batch_size", 8),
    "run_id": ("run_id", None),
}
class Request(BaseModel):
    """Request object."""

    # Prompt
    prompt: Union[str, List[str]] = ""
    # Engine
    engine: str = "text-ada-001"
    # Number completions
    n: int = 1
    # Timeout in seconds
    client_timeout: int = 60
    # Run id used to repeat run with same parameters
    run_id: Optional[str] = None
    # Batch size for async batch run
    batch_size: int = 8

    def to_dict(
        self,
        allowable_keys: Optional[Dict[str, Tuple[str, Any]]] = None,
        add_prompt: bool = True,
    ) -> Dict[str, Any]:
        """
        Convert request to a dictionary.

        Handles parameter renaming but does not fill in default values.
        It will drop any None values.

        Args:
            allowable_keys: mapping of field name -> (renamed key, default).
                When given, only these fields are emitted (under their renamed
                keys); when None, all fields are emitted under their own names.
            add_prompt: ensure "prompt" is always in the output dictionary.

        Returns:
            Dictionary of (possibly renamed) field names to non-None values.
        """
        if allowable_keys:
            include_keys = set(allowable_keys.keys())
            # Fix: original condition was `add_prompt and "prompt"`, where the
            # bare string literal "prompt" is always truthy — the intended
            # check is simply whether the caller asked for the prompt.
            if add_prompt:
                include_keys.add("prompt")
        else:
            allowable_keys = {}
            include_keys = None
        # Rename each included key via allowable_keys and drop None values;
        # fields absent from allowable_keys keep their original name.
        request_dict = {
            allowable_keys.get(k, (k, None))[0]: v
            for k, v in self.dict(include=include_keys).items()
            if v is not None
        }
        return request_dict
class LMRequest(Request):
    """Language Model Request object."""

    # Temperature for generation
    temperature: float = 0.7
    # Max tokens for generation
    max_tokens: int = 100
    # Nucleus sampling taking top_p probability mass tokens
    top_p: float = 1.0
    # Top k sampling taking top_k highest probability tokens
    top_k: int = 50
    # Logprobs return value
    logprobs: Optional[int] = None
    # Stop sequences
    stop_sequences: Optional[List[str]] = None
    # Number beams beam search (HF)
    num_beams: int = 1
    # Whether to sample or do greedy (HF)
    do_sample: bool = False
    # Penalize repetition (HF)
    repetition_penalty: float = 1.0
    # Length penalty (HF)
    length_penalty: float = 1.0
    # Penalize presence (typo "resence" fixed); float default for consistency
    # with the other float fields above.
    presence_penalty: float = 0.0
    # Penalize frequency
    frequency_penalty: float = 0.0
class LMScoreRequest(LMRequest):
    """Language Model Score Request object.

    Carries no fields of its own; it exists as a distinct type so scoring
    requests can be routed/cached separately from generation requests.
    """
class EmbeddingRequest(Request):
    """Embedding Request object.

    Adds no fields beyond the base Request; it is a marker type that
    distinguishes embedding calls from other request kinds.
    """
class DiffusionRequest(Request):
    """Diffusion Model Request object."""
    # Number of denoising/inference steps to run
    num_inference_steps: int = 50
    # Height of the generated image (presumably pixels — matches the common
    # 512x512 Stable Diffusion default; confirm against the client)
    height: int = 512
    # Width of the generated image (presumably pixels; see note on height)
    width: int = 512
    # Guidance scale — presumably classifier-free guidance strength as in
    # diffusers (higher = follow the prompt more closely); TODO confirm
    guidance_scale: float = 7.5
    # Eta — presumably the DDIM scheduler eta (0.0 = deterministic sampling
    # in diffusers); verify against the diffusion client
    eta: float = 0.0