mirror of https://github.com/HazyResearch/manifest
feat: openai embedding support (#75)
parent 693d105106
commit 40de0e7f59
@ -0,0 +1,139 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use OpenAI\n",
    "\n",
    "Set your `OPENAI_API_KEY` environment variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
     ]
    }
   ],
   "source": [
    "from manifest import Manifest\n",
    "\n",
    "manifest = Manifest(client_name=\"openaiembedding\")\n",
    "print(manifest.client.get_model_params())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1536,)\n"
     ]
    }
   ],
   "source": [
    "emb = manifest.run(\"Is this an embedding?\")\n",
    "print(emb.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using a Locally Hosted Hugging Face LM\n",
    "\n",
    "Run\n",
    "```\n",
    "python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
    "```\n",
    "in a separate `screen` or `tmux` session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model_name': 'EleutherAI/gpt-neo-125M', 'model_path': 'EleutherAI/gpt-neo-125M'}\n"
     ]
    }
   ],
   "source": [
    "from manifest import Manifest\n",
    "\n",
    "# Locally hosted GPT-Neo 125M\n",
    "manifest = Manifest(\n",
    "    client_name=\"huggingfaceembedding\",\n",
    "    client_connection=\"http://127.0.0.1:6001\",\n",
    "    cache_name=\"sqlite\",\n",
    "    cache_connection=\"my_sqlite_manifest.sqlite\"\n",
    ")\n",
    "print(manifest.client.get_model_params())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb = manifest.run(\"Is this an embedding?\")\n",
    "emb2 = manifest.run(\"Is this an embedding?\", aggregation=\"mean\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "manifest",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
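An aside for readers of the notebook above: the embedding clients return `numpy` arrays, so downstream similarity math is plain `numpy`. A minimal sketch, not part of this commit, assuming the `openaiembedding` client configured as above (`cosine_sim` is a hypothetical helper; the shape comment assumes text-embedding-ada-002's 1536-dimensional output shown in the notebook):

```python
# Sketch only (not in this commit): compare two embeddings from manifest.run.
# Assumes OPENAI_API_KEY is set, as in the notebook above.
import numpy as np

from manifest import Manifest

manifest = Manifest(client_name="openaiembedding")


def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity: dot product of the L2-normalized vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


emb_a = manifest.run("Is this an embedding?")  # np.ndarray of shape (1536,)
emb_b = manifest.run("This is an embedding.")
print(cosine_sim(emb_a, emb_b))  # nearer 1.0 for semantically similar text
```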
@ -0,0 +1,204 @@
"""OpenAI embedding client."""
import copy
import logging
import os
from typing import Any, Dict, List, Optional

import numpy as np
import tiktoken

from manifest.clients.openai import OpenAIClient
from manifest.request import EmbeddingRequest

logger = logging.getLogger(__name__)

OPENAI_EMBEDDING_ENGINES = {
    "text-embedding-ada-002",
}


class OpenAIEmbeddingClient(OpenAIClient):
    """OpenAI embedding client."""

    # User param -> (client param, default value)
    PARAMS = {
        "engine": ("model", "text-embedding-ada-002"),
    }
    REQUEST_CLS = EmbeddingRequest
    NAME = "openaiembedding"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the OpenAI server.

        connection_str is used as the OPENAI_API_KEY if the environment
        variable is not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "OpenAI API key not set. Set OPENAI_API_KEY environment "
                "variable or pass through `client_connection`."
            )
        self.host = "https://api.openai.com/v1"
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        if getattr(self, "engine") not in OPENAI_EMBEDDING_ENGINES:
            raise ValueError(
                f"Invalid engine {getattr(self, 'engine')}. "
                f"Must be {OPENAI_EMBEDDING_ENGINES}."
            )

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/embeddings"

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add them to the
        request and make sure cache keys are unique to the model.

        Returns:
            model params.
        """
        return {"model_name": self.NAME, "engine": getattr(self, "engine")}

    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response
            request: request

        Return:
            response as dict
        """
        if "data" not in response:
            raise ValueError(f"Invalid response: {response}")
        if "usage" in response:
            # Handle splitting the usages for batch requests
            if len(response["data"]) == 1:
                if isinstance(response["usage"], list):
                    response["usage"] = response["usage"][0]
                response["usage"] = [response["usage"]]
            else:
                # Try to split usage
                split_usage = self.split_usage(request, response["data"])
                if split_usage:
                    response["usage"] = split_usage
        return response

    def _format_request_for_embedding(self, request_params: Dict[str, Any]) -> Dict:
        """Format request params for the embedding endpoint.

        Args:
            request_params: request params.

        Returns:
            formatted request params.
        """
        # The embedding endpoint takes `input` (a list of strings)
        # rather than `prompt`
        request_params = copy.deepcopy(request_params)
        prompt = request_params.pop("prompt")
        if isinstance(prompt, str):
            prompt_list = [prompt]
        else:
            prompt_list = prompt
        request_params["input"] = prompt_list
        return request_params

    def _format_request_from_embedding(self, response_dict: Dict[str, Any]) -> Dict:
        """Format response from the embedding endpoint as a standard response.

        Args:
            response_dict: response.

        Return:
            formatted response.
        """
        new_choices = []
        response_dict = copy.deepcopy(response_dict)
        for res in response_dict.pop("data"):
            new_choices.append({"array": np.array(res["embedding"])})
        response_dict["choices"] = new_choices
        return response_dict

    def _run_completion(
        self, request_params: Dict[str, Any], retry_timeout: int
    ) -> Dict:
        """Execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.

        Returns:
            response as dict.
        """
        # Format for the embedding endpoint
        request_params = self._format_request_for_embedding(request_params)
        response_dict = super()._run_completion(request_params, retry_timeout)
        # Reformat into the standard text-model response
        response_dict = self._format_request_from_embedding(response_dict)
        return response_dict

    async def _arun_completion(
        self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
    ) -> Dict:
        """Async execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.
            batch_size: batch size for requests.

        Returns:
            response as dict.
        """
        # Format for the embedding endpoint
        request_params = self._format_request_for_embedding(request_params)
        response_dict = await super()._arun_completion(
            request_params, retry_timeout, batch_size
        )
        # Reformat into the standard text-model response
        response_dict = self._format_request_from_embedding(response_dict)
        return response_dict

    def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
        """Split usage into a list of usages, one per prompt."""
        try:
            encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
        except Exception:
            return []
        prompt = request["input"]
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt
        assert len(prompts) == len(choices)
        usages = []
        for pmt in prompts:
            pmt_tokens = len(encoding.encode(pmt))
            # Embedding models produce no completion tokens
            chc_tokens = 0
            usage = {
                "prompt_tokens": pmt_tokens,
                "completion_tokens": chc_tokens,
                "total_tokens": pmt_tokens + chc_tokens,
            }
            usages.append(usage)
        return usages
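To make the reshaping done by `_format_request_for_embedding` and `_format_request_from_embedding` concrete, here is a standalone sketch of the same transformations on plain dicts. The request and response values below are fabricated in the shape the code above expects, not captured from the API, and the two helper names are hypothetical:

```python
# Sketch only: the prompt -> input and data -> choices round trip.
import copy
from typing import Any, Dict

import numpy as np


def to_embedding_request(request_params: Dict[str, Any]) -> Dict:
    # Mirrors _format_request_for_embedding: move `prompt` into `input`,
    # always as a list so batch and single requests look the same.
    request_params = copy.deepcopy(request_params)
    prompt = request_params.pop("prompt")
    request_params["input"] = [prompt] if isinstance(prompt, str) else prompt
    return request_params


def from_embedding_response(response_dict: Dict[str, Any]) -> Dict:
    # Mirrors _format_request_from_embedding: turn each `data` entry
    # into a choice holding a numpy array.
    response_dict = copy.deepcopy(response_dict)
    response_dict["choices"] = [
        {"array": np.array(res["embedding"])} for res in response_dict.pop("data")
    ]
    return response_dict


req = to_embedding_request(
    {"prompt": "Is this an embedding?", "model": "text-embedding-ada-002"}
)
# {'model': 'text-embedding-ada-002', 'input': ['Is this an embedding?']}

fake_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}], "usage": {"prompt_tokens": 6}}
out = from_embedding_response(fake_response)
print(out["choices"][0]["array"].shape)  # (3,)
```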