fix: google models

pull/98/head
Laurel Orr 1 year ago
parent 4903c7e7e8
commit a982728f75

@@ -1,5 +1,13 @@
0.1.8 - Unreleased
---------------------
Added
^^^^^
* Azure model support
* Google Vertex API model support
Fixed
^^^^^
* `run` with a list of prompts now behaves the same as the async run, just synchronously: requests are split into appropriately sized batches.
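For reference, a minimal sketch of the batched synchronous path (constructor defaults and an OpenAI key in the environment are assumed; they are not part of this change):

    from manifest import Manifest

    # Illustrative only: a single synchronous `run` over a list of prompts.
    # The client now splits the list into batch-sized chunks internally,
    # mirroring the async path.
    manifest = Manifest(client_name="openai")
    results = manifest.run(["Name a color.", "Name a fruit.", "Name a planet."])
    print(results)  # one completion per prompt, in order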
0.1.7 - 2023-05-17
---------------------

@@ -246,6 +246,29 @@ class Client(ABC):
request_params_list.append(params)
return request_params_list
def stitch_responses(self, request: Request, responses: List[Dict]) -> Response:
"""Stitch responses together.
Useful for batch requests.
"""
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
final_usages = None
if usages:
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
return Response(
self.get_model_choices(final_response_dict),
cached=False,
request=request,
usages=final_usages,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
@retry(
reraise=True,
retry=retry_if_ratelimit,
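To make the new helper concrete, a small standalone sketch of the merge it performs (the per-batch response shape is assumed from the loop above; no Manifest classes are needed):

    # Illustrative only: two per-batch response dicts, merged the same way
    # stitch_responses merges them before building the final Response.
    batch_responses = [
        {"choices": [{"text": "Paris"}], "usage": [{"total_tokens": 5}]},
        {"choices": [{"text": "Berlin"}], "usage": [{"total_tokens": 6}]},
    ]

    choices, usages = [], []
    for res in batch_responses:
        choices.extend(res["choices"])
        usages.extend(res.get("usage", []))

    merged = {"choices": choices}
    print(merged)  # {'choices': [{'text': 'Paris'}, {'text': 'Berlin'}]}
    print(usages)  # [{'total_tokens': 5}, {'total_tokens': 6}]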
@@ -285,14 +308,13 @@ class Client(ABC):
stop=stop_after_attempt(10),
)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
@@ -319,29 +341,44 @@ class Client(ABC):
Returns:
response.
"""
if isinstance(request.prompt, list) and not self.supports_batch_inference():
raise ValueError(
# Make everything a list for consistency
if isinstance(request.prompt, list):
prompt_list = request.prompt
else:
prompt_list = [request.prompt]
request_params = self.get_request_params(request)
# Set the prompt as a list on the request params only. Do not set it on
# the request object itself, as the cache would then store the prompt as
# a list, which is inconsistent with the original request input.
request_params["prompt"] = prompt_list
# If batch_size is not set, set it to 1
batch_size = request_params.pop("batch_size") or 1
if not self.supports_batch_inference():
logger.warning(
f"{self.__class__.__name__} does not support batch inference."
" Setting batch size to 1."
)
batch_size = 1
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
usages = None
if "usage" in response_dict:
usages = [Usage(**usage) for usage in response_dict["usage"]]
return Response(
response=self.get_model_choices(response_dict),
cached=False,
request=request,
usages=Usages(usages=usages) if usages else None,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
# Batch requests
num_batches = len(prompt_list) // batch_size
if len(prompt_list) % batch_size != 0:
batch_size = int(math.ceil(len(prompt_list) / (num_batches + 1)))
request_batches = self.split_requests(request_params, batch_size)
response_dicts = [
self._run_completion(batch, retry_timeout) for batch in request_batches
]
# Flatten responses
return self.stitch_responses(request, response_dicts)
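A worked example of the batch-size rebalancing above, as a standalone sketch (split_requests is assumed to chunk the prompt list contiguously in the same way):

    import math

    # Illustrative only: 10 prompts with a requested batch_size of 6.
    prompts = [f"prompt {i}" for i in range(10)]
    batch_size = 6

    num_batches = len(prompts) // batch_size  # 1
    if len(prompts) % batch_size != 0:
        # ceil(10 / 2) = 5, so the remainder is spread evenly
        batch_size = int(math.ceil(len(prompts) / (num_batches + 1)))

    batches = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
    print([len(b) for b in batches])  # [5, 5] rather than [6, 4]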
async def arun_batch_request(self, request: Request) -> Response:
"""
@@ -376,28 +413,12 @@ class Client(ABC):
request_batches = self.split_requests(request_params, batch_size)
all_tasks = [
asyncio.create_task(self._arun_completion(batch, retry_timeout, batch_size))
asyncio.create_task(self._arun_completion(batch, retry_timeout))
for batch in request_batches
]
responses = await asyncio.gather(*all_tasks)
# Flatten responses
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
final_usages = None
if usages:
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
return Response(
self.get_model_choices(final_response_dict),
cached=False,
request=request,
usages=final_usages,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
return self.stitch_responses(request, responses)
def run_chat_request(
self,

@@ -0,0 +1,201 @@
"""Google client."""
import logging
import os
import subprocess
from typing import Any, Dict, Optional, Type
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
GOOGLE_ENGINES = {
"text-bison",
}
def get_project_id() -> Optional[str]:
"""Get project ID.
Run
`gcloud config get-value project`
"""
try:
project_id = subprocess.run(
["gcloud", "config", "get-value", "project"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if project_id.stderr.decode("utf-8").strip():
return None
return project_id.stdout.decode("utf-8").strip()
except Exception:
return None
class GoogleClient(Client):
"""Google client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "text-bison"),
"temperature": ("temperature", 1.0),
"max_tokens": ("maxOutputTokens", 10),
"top_p": ("topP", 1.0),
"top_k": ("topK", 1),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS: Type[Request] = LMRequest
NAME = "google"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the GoogleVertex API.
connection_str is used as the default GOOGLE_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("GOOGLE_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
"variable or pass through `client_connection`. This can be "
"found by running `gcloud auth print-access-token`"
)
self.project_id = os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
if self.project_id is None:
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in GOOGLE_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
)
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
model = getattr(self, "engine")
return self.host + f"/{model}:predict"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"Authorization": f"Bearer {self.api_key}"}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def _reformat_request_for_google(
self, request_params: Dict[str, Any]
) -> Dict[str, Any]:
"""Reformat request for Google."""
# Reformat the request params for Google
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
google_request_params = {
"instances": [{"prompt": prompt} for prompt in prompt_list],
"parameters": request_params,
}
return google_request_params
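For reference, a sketch of the Vertex `predict` body this reformatting produces (parameter names follow the PARAMS mapping above; the values are made up):

    # Illustrative only: flattened request params in, Vertex-style body out.
    request_params = {
        "prompt": ["Write a haiku about rain.", "Write a haiku about snow."],
        "temperature": 0.7,
        "maxOutputTokens": 64,
        "topP": 1.0,
        "topK": 1,
    }

    prompt = request_params.pop("prompt")
    body = {
        "instances": [{"prompt": p} for p in prompt],
        "parameters": request_params,
    }
    # body["instances"]  -> one {"prompt": ...} dict per input prompt
    # body["parameters"] -> the remaining sampling parameters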
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
return super()._run_completion(
self._reformat_request_for_google(request_params), retry_timeout
)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
return await super()._arun_completion(
self._reformat_request_for_google(request_params), retry_timeout
)
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Validate response as dict.
Assumes response is dict
{
"predictions": [
{
"safetyAttributes": {
"categories": ["Violent", "Sexual"],
"blocked": false,
"scores": [0.1, 0.1]
},
"content": "SELECT * FROM "WWW";"
}
]
}
Args:
response: response
request: request
Return:
response as dict
"""
google_predictions = response.pop("predictions")
new_response = {
"choices": [
{
"text": prediction["content"],
}
for prediction in google_predictions
]
}
return super().validate_response(new_response, request)

@@ -0,0 +1,182 @@
"""GoogleChat client."""
import copy
import logging
import os
from typing import Any, Dict, Optional, Type
from manifest.clients.google import GoogleClient, get_project_id
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
GOOGLE_ENGINES = {
"chat-bison",
}
class GoogleChatClient(GoogleClient):
"""GoogleChat client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "chat-bison"),
"temperature": ("temperature", 1.0),
"max_tokens": ("maxOutputTokens", 10),
"top_p": ("topP", 1.0),
"top_k": ("topK", 1),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS: Type[Request] = LMRequest
NAME = "googlechat"
IS_CHAT = True
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the GoogleVertex API.
connection_str is used as the default GOOGLE_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("GOOGLE_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
"variable or pass through `client_connection`. This can be "
"found by running `gcloud auth print-access-token`"
)
self.project_id = os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
if self.project_id is None:
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in GOOGLE_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
)
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for chat.
Args:
request_params: request params.
Returns:
formatted request params.
"""
# Format for chat model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
messages = [{"author": "user", "content": prompt}]
elif isinstance(prompt, list) and isinstance(prompt[0], str):
prompt_list = prompt
messages = [{"author": "user", "content": prompt} for prompt in prompt_list]
elif isinstance(prompt, list) and isinstance(prompt[0], dict):
for pmt_dict in prompt:
if "author" not in pmt_dict or "content" not in pmt_dict:
raise ValueError(
"Prompt must be list of dicts with 'author' and 'content' "
f"keys. Got {prompt}."
)
messages = prompt
else:
raise ValueError(
"Prompt must be string, list of strings, or list of dicts."
f"Got {prompt}"
)
new_request = {
"instances": [{"messages": messages}],
"parameters": request_params,
}
return new_request
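The prompt shapes accepted above, and the payload built for the already-formed-messages case, as a short sketch (author labels and parameter values are placeholders):

    # Illustrative only: the three accepted prompt shapes.
    single = "What is the capital of France?"        # one user message
    several = ["What is 1 + 1?", "What is 2 + 2?"]   # one user message per prompt
    multi_turn = [                                    # messages passed through as-is
        {"author": "user", "content": "Suggest a name for a coffee shop."},
        {"author": "bot", "content": "How about 'Daily Grind'?"},
        {"author": "user", "content": "Something shorter, please."},
    ]

    # What the multi-turn form becomes for the Vertex chat `predict` endpoint.
    payload = {
        "instances": [{"messages": multi_turn}],
        "parameters": {"temperature": 0.7, "maxOutputTokens": 64},
    }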
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = super(GoogleClient, self)._run_completion(
request_params, retry_timeout
)
# Validate response handles the reformatting
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = await super(GoogleClient, self)._arun_completion(
request_params, retry_timeout
)
# Validate response handles the reformatting
return response_dict
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Validate response as dict.
Assumes response is dict
{
"candidates": [
{
"safetyAttributes": {
"categories": ["Violent", "Sexual"],
"blocked": false,
"scores": [0.1, 0.1]
},
"author": "1",
"content": "SELECT * FROM "WWW";"
}
]
}
Args:
response: response
request: request
Return:
response as dict
"""
google_predictions = response.pop("predictions")
new_response = {
"choices": [
{
"text": prediction["candidates"][0]["content"],
}
for prediction in google_predictions
]
}
return super(GoogleClient, self).validate_response(new_response, request)

@@ -149,23 +149,20 @@ class OpenAIChatClient(OpenAIClient):
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
response_dict = await super()._arun_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_chat(response_dict)
return response_dict

@@ -157,23 +157,20 @@ class OpenAIEmbeddingClient(OpenAIClient):
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
response_dict = await super()._arun_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_embedding(response_dict)
return response_dict

@@ -9,6 +9,8 @@ from manifest.clients.ai21 import AI21Client
from manifest.clients.client import Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.google import GoogleClient
from manifest.clients.google_chat import GoogleChatClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
@@ -21,14 +23,16 @@ logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
CohereClient.NAME: CohereClient,
DummyClient.NAME: DummyClient,
GoogleClient.NAME: GoogleClient,
GoogleChatClient.NAME: GoogleChatClient,
HuggingFaceClient.NAME: HuggingFaceClient,
HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
DummyClient.NAME: DummyClient,
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
TOMAClient.NAME: TOMAClient,
}
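With both clients registered, selecting them looks roughly like this (the Manifest constructor is assumed from the existing clients; GOOGLE_API_KEY and GOOGLE_PROJECT_ID, or a gcloud login, must be configured as described in GoogleClient.connect):

    from manifest import Manifest

    # Illustrative only: the new Vertex clients are selected by NAME.
    text_lm = Manifest(client_name="google")       # text-bison completion model
    chat_lm = Manifest(client_name="googlechat")   # chat-bison chat model

    print(text_lm.run("Write a one-line SQL joke."))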

@@ -511,7 +511,7 @@ class Manifest:
if not isinstance(prompts[0], str):
raise ValueError("Prompts must be a list of strings.")
# Split the prompts into chunks
# Split the prompts into chunks for connection pool
prompt_chunks: List[Tuple[Client, List[str]]] = []
if chunk_size > 0:
for i in range(0, len(prompts), chunk_size):
@@ -534,7 +534,6 @@ class Manifest:
)
)
)
print(f"Running {len(tasks)} tasks across all clients.")
logger.info(f"Running {len(tasks)} tasks across all clients.")
responses = await asyncio.gather(*tasks)
final_response = Response.union_all(responses)
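A rough sketch of the chunked async path this hunk touches (the arun_batch name and chunk_size argument are assumed, not introduced here):

    import asyncio

    from manifest import Manifest

    # Illustrative only: prompts are split into chunk_size pieces for the
    # connection pool, each chunk runs as its own task, and the per-task
    # Responses are unioned into one.
    manifest = Manifest(client_name="openai")
    prompts = [f"Summarize document {i}." for i in range(40)]
    response = asyncio.run(manifest.arun_batch(prompts, chunk_size=10))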
