mirror of https://github.com/HazyResearch/manifest
Zoo (#23)
* Zoo Model * Remove optional import zoo * Read model name from zoo model * Logprobs passed through raw response for gold choices Co-authored-by: Simran <emailsimran@gmail.com> Co-authored-by: Dan Fu <danfu@cs.stanford.edu> (branch: laurel/helm)
parent
e0a76d1f93
commit
5428afdc58
@ -0,0 +1,94 @@
|
||||
"""Zoo model."""
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from manifest.api.models.model import Model
|
||||
|
||||
# The Zoo code base is not pip-installable; its checkout location must be
# supplied via the ZOO_PATH environment variable so it can be imported below.
ZOO_PATH = os.environ.get("ZOO_PATH", None)
if not ZOO_PATH:
    raise ImportError("ZOO_PATH environment variable not set.")
sys.path.append(ZOO_PATH)

# Resolved from ZOO_PATH (appended to sys.path above), not from this package,
# hence the late, non-top-of-file import.
from src.models.s4_seq import S4LMManifest  # type: ignore
|
||||
|
||||
|
||||
class ZooModel(Model):
    """Zoo model.

    Wraps an S4 language model loaded from the external Zoo repository
    (located via ``ZOO_PATH``) so it can be served through the Manifest
    model API.
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_config: str,
        cache_dir: str,
        device: int,
        use_accelerate: bool,
        use_parallelize: bool,
        perc_max_gpu_mem_red: float,
        use_fp16: bool,
    ):
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.

        Args:
            model_name_or_path: model name string.
            model_config: model config path.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize.
            perc_max_gpu_mem_red: percent max memory reduction in accelerate.
            use_fp16: use fp16 for model weights.

        Raises:
            ValueError: if no model config is provided.
        """
        # For Zoo models, `model_name_or_path` is treated as a weights path.
        self.model_path = model_name_or_path
        self.model_config = model_config
        if not self.model_config:
            raise ValueError("Must provide model config.")
        self.model = S4LMManifest(
            config_path=self.model_config,
            weights_path=self.model_path,
        )
        # Can only load this after the model has been initialized.
        self.model_name = self.model.get_model_name()

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "model_config": self.model_config,
        }

    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
        """
        Generate the prompt from model.

        Outputs must be generated text, not including prompt.

        Args:
            prompt: prompt to generate from.

        Returns:
            list of generated text (list of length 1 for 1 generation).
        """
        # Generation is fully delegated to the underlying Zoo model.
        # (A stray debug ``print(prompt)`` was removed here.)
        return self.model.generate(prompt, **kwargs)

    def logits_scoring(
        self, prompt: str, gold_choices: List[str], **kwargs: Any
    ) -> Tuple[str, float]:
        """
        Given the prompt and gold choices, choose the best choice with max logits.

        Args:
            prompt: prompt to generate from.
            gold_choices: list of choices to choose from.

        Returns:
            the returned gold choice and the score.

        Raises:
            NotImplementedError: Zoo models do not support logit scoring.
        """
        raise NotImplementedError()
|
@ -0,0 +1,102 @@
|
||||
"""Zoo client."""
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from manifest.clients.client import Client
|
||||
|
||||
# Module-level logger, following the stdlib per-module convention.
logger = logging.getLogger(__name__)

# User param -> (client param, default value)
# Currently empty: the Zoo server accepts arbitrary request params
# (see the greedy pass-through in get_request).
ZOO_PARAMS: Dict[str, Tuple[str, str]] = {}
|
||||
|
||||
|
||||
class ZooClient(Client):
    """Zoo client.

    Talks over HTTP to a Manifest model server that hosts a Zoo model.
    """

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect to the model.

        Args:
            connection_str: connection string (base URL of the model server).
            client_args: client arguments.

        Raises:
            ValueError: if no connection string is provided.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if not connection_str:
            raise ValueError("Must provide connection string for Zoo client.")
        # Avoid the shared-mutable-default pitfall: the loop below pops from
        # this dict, which would otherwise mutate the default across calls.
        if client_args is None:
            client_args = {}
        self.host = connection_str.rstrip("/")
        for key in ZOO_PARAMS:
            setattr(self, key, client_args.pop(key, ZOO_PARAMS[key][1]))
        self.model_params = self.get_model_params()

    def close(self) -> None:
        """Close the client."""
        # Nothing to release: requests are issued per-call without a session.
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        # NOTE(review): the server exposes /params as a POST endpoint;
        # no error handling here, so a bad host surfaces as a JSON error.
        res = requests.post(self.host + "/params")
        return res.json()

    def get_model_inputs(self) -> List:
        """
        Get allowable model inputs.

        Returns:
            model inputs.
        """
        return list(ZOO_PARAMS.keys())

    def get_request(
        self, query: str, request_args: Optional[Dict[str, Any]] = None
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.
            request_args: per-request overrides; consumed (popped) by this call.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        # Avoid a mutable default: the loop below pops entries, which would
        # otherwise accumulate state on the shared default dict.
        if request_args is None:
            request_args = {}
        request_params: Dict[str, Any] = {"prompt": query}
        # Zoo is greedy and takes all params
        # TODO: Once zoo is finalized, fix this
        for key in list(request_args.keys()):
            request_params[key] = request_args.pop(key, None)
        request_params.update(self.model_params)

        def _run_completion() -> Dict:
            post_str = self.host + "/completions"
            res = requests.post(post_str, json=request_params)
            return res.json()

        return _run_completion, request_params

    def get_choice_logit_request(
        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function for choosing max choices.

        Args:
            query: query string.
            gold_choices: choices for model to choose from via max logits.

        Returns:
            request function that takes no input.
            request parameters as dict.

        Raises:
            NotImplementedError: always; Zoo does not support this request.
        """
        # The default dict is never read or mutated here (the method raises
        # unconditionally), so it is left as-is to keep the signature
        # byte-compatible with the base Client interface.
        raise NotImplementedError("Zoo does not support choice logit request.")
|
Loading…
Reference in New Issue