[feat] toma client (#36)

* [feat] toma client

* Update toma.py

* [chore] clean up model heartbeat toma

Co-authored-by: Karan Goel <kgoel93@gmail.com>
parent 5bd3175ee0
commit 889837f013

@@ -3,3 +3,5 @@ from manifest.manifest import Manifest
from manifest.prompt import Prompt
from manifest.response import Response
from manifest.session import Session
__all__ = ["Manifest", "Prompt", "Response", "Session"]

@@ -1,2 +1 @@
"""Cache init."""
-from manifest.caches.cache import Cache

@@ -1,7 +1,7 @@
"""Noop cache."""
from typing import Any, Dict, Union
-from manifest.caches import Cache
+from manifest.caches.cache import Cache
class NoopCache(Cache):

@@ -3,7 +3,7 @@ from typing import Any, Dict, Union
import redis
-from manifest.caches import Cache
+from manifest.caches.cache import Cache
class RedisCache(Cache):

@@ -4,7 +4,7 @@ from typing import Any, Dict, Union
from sqlitedict import SqliteDict
-from manifest.caches import Cache
+from manifest.caches.cache import Cache
logging.getLogger("sqlitedict").setLevel(logging.WARNING)

@@ -0,0 +1,276 @@
"""TOMA client."""
import logging
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import requests
from manifest.clients.client import Client
logger = logging.getLogger(__name__)
TOMA_ENGINES = {
"bloom",
"glm-int8",
"gpt-neox-20b",
"opt-66b",
"opt-175b",
"glm",
"stable_diffusion",
"t0pp",
"gpt-j-6b",
"t5-11b",
"glm-int4",
"ul2",
}
# Engine -> request type
# Default is language-model-inference
TOMA_ENGINE_REQUEST_TYPE = {
"stable_diffusion": "image-model-inference",
}
# User param -> (client param, default value)
TOMA_PARAMS = {
"engine": ("model", "gpt-j-6b"),
"temperature": ("temperature", 1.0),
"max_tokens": ("max_tokens", 10),
"n": ("n", 1),
"p": ("top_p", 1.0),
"stop_sequences": ("stop", []),
"best_of": ("best_of", 1),
"logprobs": ("logprobs", 0),
"prompt_embedding": ("prompt_embedding", False),
"echo": ("echo", False),
"rate_limit_retry_timeout": ("rate_limit_retry_timeout", 120), # seconds
}
class TOMAClient(Client):
"""TOMA client."""
pending_jobs: List = []
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
        Connect to the TOMA URL.
        Args:
connection_str: connection string.
client_args: client arguments.
"""
self.host = os.environ.get("TOMA_URL", None)
if not self.host:
raise ValueError("TOMA_URL environment variable not set.")
# self.api_key = os.environ.get("TOMA_API_KEY", connection_str)
# if self.api_key is None:
# raise ValueError(
# "TOMA API key not set. Set TOMA_API_KEY environment "
# "variable or pass through `connection_str`."
# )
for key in TOMA_PARAMS:
setattr(self, key, client_args.pop(key, TOMA_PARAMS[key][1]))
if getattr(self, "engine") not in TOMA_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {TOMA_ENGINES}."
)
model_heartbeats = self.get_model_heartbeats()
model_heartbeat_threshold = 60
print("TOMA MODEL HEARTBEATS\n", model_heartbeats)
if (
model_heartbeats[getattr(self, "engine")]["last_ping"]
> model_heartbeat_threshold
):
raise ValueError(
f"Model {getattr(self, 'engine')} has not been pinged in "
f"{model_heartbeats[getattr(self, 'engine')]} seconds."
)
if model_heartbeats[getattr(self, "engine")]["expected_runtime"] > getattr(
self, "rate_limit_retry_timeout"
):
logger.warning(
f"Model {getattr(self, 'engine')} has expected runtime "
f"{model_heartbeats[getattr(self, 'engine')]['expected_runtime']} "
f"and may take longer than {getattr(self, 'rate_limit_retry_timeout')} "
"seconds to respond. Increase rate_limit_retry_timeout "
"to avoid timeout."
)
def close(self) -> None:
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
        By getting model params from the server, we can add them to the
        request and ensure cache keys are unique to each model.
Returns:
model params.
"""
return {"model_name": "toma", "engine": getattr(self, "engine")}
def get_model_inputs(self) -> List:
"""
Get allowable model inputs.
Returns:
model inputs.
"""
return list(TOMA_PARAMS.keys())
def get_model_heartbeats(self) -> Dict[str, Dict]:
"""
Get TOMA models and their last ping time.
        Some TOMA models are not loaded and will not respond.
Returns:
model name to time since last ping (sec).
"""
res = requests.get(self.host + "/model_statuses").json()
heartbeats = {}
for mod in res:
mod_time = datetime.fromisoformat(mod["last_heartbeat"])
now = datetime.now(mod_time.tzinfo)
heartbeats[mod["name"]] = {
"last_ping": (now - mod_time).total_seconds(),
"expected_runtime": mod["expected_runtime"],
}
return heartbeats
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
        Returns:
response as dict
"""
return {
"model": getattr(self, "engine"),
"choices": [
{
"text": item["text"],
# "logprobs": [],
}
for item in response["inference_result"][0]["choices"]
],
}
def get_response(self, job_id: str, retry_timeout: int) -> Dict[str, Any]:
"""
Get response from job id.
        Will try for up to `retry_timeout` seconds to get the response.
Args:
job_id: job id
retry_timeout: retry timeout
Returns:
response as dict
"""
final_res = None
attempts = 0
while True:
ret = requests.get(f"{self.host}/job/{job_id}", json={"id": job_id}).json()
attempts += 1
if ret["status"] == "finished" or ret["status"] == "failed":
final_res = ret["returned_payload"]
break
if attempts > retry_timeout:
break
time.sleep(1)
if not final_res:
raise RuntimeError(
f"TOMA request timed out after {retry_timeout}s with {ret['status']}."
)
if "result" in final_res:
return self.format_response(final_res["result"])
else:
raise RuntimeError(f"TOMA request failed with {final_res['message']}.")
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"prompt": [query],
"request_type": TOMA_ENGINE_REQUEST_TYPE.get(
getattr(self, "engine"), "language-model-inference"
),
}
for key in TOMA_PARAMS:
if key in ["rate_limit_retry_timeout"]:
# These are not passed to the TOMA API
continue
request_params[TOMA_PARAMS[key][0]] = request_args.pop(
key, getattr(self, key)
)
retry_timeout = request_args.pop(
"rate_limit_retry_timeout", getattr(self, "rate_limit_retry_timeout")
)
# num_returns is for image-model-inference
if request_params["request_type"] == "image-model-inference":
request_params["num_returns"] = request_params["n"]
def _run_completion() -> Dict:
post_str = self.host + "/jobs"
res = requests.post(
post_str,
# headers={"Authorization": f"Bearer {self.api_key}"},
json={
"type": "general",
"payload": request_params,
"returned_payload": {},
"status": "submitted",
"source": "dalle",
},
).json()
job_id = res["id"]
print(f"job_id: {job_id}")
# TODO: ideally just submit the jobs and then fetch results in parallel
self.pending_jobs.append(job_id)
job_id = self.pending_jobs.pop()
return self.get_response(job_id, retry_timeout)
return _run_completion, request_params
def get_choice_logit_request(
self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function for choosing max choices.
Args:
query: query string.
gold_choices: choices for model to choose from via max logits.
Returns:
request function that takes no input.
request parameters as dict.
"""
raise NotImplementedError("TOMA does not support choice logit request.")
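
For reference, a minimal driver sketch for the new client on its own. It assumes the Client base constructor forwards connection_str and client_args to connect(), and that TOMA_URL points at a live deployment; the endpoint and values here are illustrative, not part of the commit:

import os

from manifest.clients.toma import TOMAClient

os.environ["TOMA_URL"] = "http://localhost:5000"  # hypothetical endpoint

# Assumes Client.__init__ passes client_args through to connect().
client = TOMAClient(client_args={"engine": "gpt-j-6b", "max_tokens": 16})
run_fn, request_params = client.get_request("Where do hippos live?")
response = run_fn()  # submits a job, then polls /job/<id> until it finishes
print(response["choices"][0]["text"])
client.close()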

@@ -13,6 +13,7 @@ from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.opt import OPTClient
+from manifest.clients.toma import TOMAClient
from manifest.clients.zoo import ZooClient
from manifest.prompt import Prompt
from manifest.response import Response
@@ -28,6 +29,7 @@ CLIENT_CONSTRUCTORS = {
"huggingface": HuggingFaceClient,
"opt": OPTClient,
"dummy": DummyClient,
"toma": TOMAClient,
"zoo": ZooClient,
}
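
With the constructor registered above, the client should also be reachable through the top-level Manifest interface. A minimal sketch, assuming Manifest forwards extra keyword arguments to the underlying client and that TOMA_URL points at a live deployment:

import os

from manifest import Manifest

os.environ["TOMA_URL"] = "http://localhost:5000"  # hypothetical endpoint

manifest = Manifest(client_name="toma", engine="gpt-j-6b", max_tokens=10)
print(manifest.run("Where do hippos live?"))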
