mirror of https://github.com/HazyResearch/manifest
Merge remote-tracking branch 'upstream/main'
commit 3084200233
@@ -0,0 +1,27 @@
import asyncio
import time

from manifest import Manifest


def main():

    manifest = Manifest(
        client_name="openaichat",
    )

    print("Running in serial")
    prompts = [f"Tell me something interesting about {i}" for i in range(50)]
    st = time.time()
    for pmt in prompts:
        _ = manifest.run(pmt)
    print(f"For loop: {time.time() - st :.2f}")

    print("Running with async")
    st = time.time()
    _ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30))
    print(f"Async loop: {time.time() - st :.2f}")


if __name__ == "__main__":
    main()
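The example above drives `arun_batch` with `asyncio.run`, which only works from synchronous code. Below is a minimal sketch, not part of the commit and assuming `OPENAI_API_KEY` is set, of awaiting the same batch call from code that is already inside a running event loop (for example a notebook cell), where `asyncio.run` would raise `RuntimeError`:

```python
from manifest import Manifest


async def run_batch() -> None:
    # Same client as the example above; the API key is read from the environment.
    manifest = Manifest(client_name="openaichat")
    prompts = [f"Tell me something interesting about {i}" for i in range(10)]
    # arun_batch is a coroutine, so inside a running loop it is simply awaited.
    _ = await manifest.arun_batch(prompts, max_tokens=30)
```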
@@ -1,63 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from manifest import Manifest\n",
    "import os\n",
    "\n",
    "# ChatGPT tries hard not to give people programmatic access.\n",
    "# As a warning, this will open a browser window.\n",
    "# You need to install xvfb and chromium for linux headless mode to work\n",
    "# See https://github.com/terry3041/pyChatGPT\n",
    "\n",
    "# The responses are not fast\n",
    "manifest = Manifest(\n",
    "    client_name=\"chatgpt\",\n",
    "    client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
    ")\n",
    "print(manifest.run(\"Describe in a single, short sentence what is the best sandwhich in the world. Be short and concise.\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlcore",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1ea9cc00d433352044b557b1784ac6e58df03de4b7bb312554014351989eb135"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,156 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use OpenAI\n",
    "\n",
    "Set your `OPENAI_API_KEY` environment variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
     ]
    }
   ],
   "source": [
    "from manifest import Manifest\n",
    "\n",
    "manifest = Manifest(client_name=\"openaiembedding\")\n",
    "print(manifest.client.get_model_params())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1536,)\n"
     ]
    }
   ],
   "source": [
    "emb = manifest.run(\"Is this an embedding?\")\n",
    "print(emb.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Locally Hosted Huggingface LM\n",
    "\n",
    "Run\n",
    "```\n",
    "python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
    "```\n",
    "or\n",
    "```\n",
    "python3 manifest/api/app.py --model_type sentence_transformers --model_name_or_path all-mpnet-base-v2 --device 0\n",
    "```\n",
    "\n",
    "in a separate `screen` or `tmux`. Make sure to note the port. You can change this with `export FLASK_PORT=<port>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model_name': 'all-mpnet-base-v2', 'model_path': 'all-mpnet-base-v2', 'client_name': 'huggingfaceembedding'}\n"
     ]
    }
   ],
   "source": [
    "from manifest import Manifest\n",
    "\n",
    "# Local hosted GPT Neo 125M\n",
    "manifest = Manifest(\n",
    "    client_name=\"huggingfaceembedding\",\n",
    "    client_connection=\"http://127.0.0.1:6000\",\n",
    "    cache_name=\"sqlite\",\n",
    "    cache_connection=\"my_sqlite_manifest.sqlite\"\n",
    ")\n",
    "print(manifest.client.get_model_params())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(768,)\n",
      "(768,) (768,)\n"
     ]
    }
   ],
   "source": [
    "emb = manifest.run(\"Is this an embedding?\")\n",
    "print(emb.shape)\n",
    "\n",
    "emb = manifest.run([\"Is this an embedding?\", \"Bananas!!!\"])\n",
    "print(emb[0].shape, emb[1].shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "manifest",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,113 @@
"""Sentence transformer model."""
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from manifest.api.models.model import Model


class SentenceTransformerModel(Model):
    """SentenceTransformer model."""

    def __init__(
        self,
        model_name_or_path: str,
        model_type: Optional[str] = None,
        model_config: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: int = 0,
        use_accelerate: bool = False,
        use_parallelize: bool = False,
        use_bitsandbytes: bool = False,
        use_deepspeed: bool = False,
        perc_max_gpu_mem_red: float = 1.0,
        use_fp16: bool = False,
    ):
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.

        Args:
            model_name_or_path: model name string.
            model_config: model config string.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize
            use_bitsandbytes: use HF bits and bytes
            use_deepspeed: use deepspeed
            perc_max_gpu_mem_red: percent max memory reduction in accelerate
            use_fp16: use fp16 for model weights.
        """
        if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
            raise ValueError(
                "Cannot use accelerate or parallelize or "
                "bitsandbytes or deepspeed with sentence transformers"
            )
        # Check if providing path
        self.model_name = model_name_or_path
        print("Model Name:", self.model_name)
        torch_device = (
            torch.device("cpu")
            if (device == -1 or not torch.cuda.is_available())
            else torch.device(f"cuda:{device}")
        )
        self.embedding_model = SentenceTransformer(self.model_name, device=torch_device)
        self.embedding_model.to(torch_device)
        self.embedding_model.eval()

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {"model_name": self.model_name, "model_path": self.model_name}

    @torch.no_grad()
    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float, List[int], List[float]]]:
        """
        Generate the prompt from model.

        Outputs must be generated text and score, not including prompt.

        Args:
            prompt: prompt to generate from.

        Returns:
            list of generated text (list of length 1 for 1 generation).
        """
        raise NotImplementedError("Generate not supported for sentence transformers")

    @torch.no_grad()
    def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
        """
        Embed the prompt from model.

        Args:
            prompt: prompt to embed from.

        Returns:
            list of embeddings (list of length 1 for 1 embedding).
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        return self.embedding_model.encode(prompt)

    @torch.no_grad()
    def score_sequence(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[float, List[int], List[float]]]:
        """
        Score a sequence of choices.

        Args:
            prompt (:obj:`str` or :obj:`List[str]`):
                The prompt to score the choices against.
            **kwargs:
                Additional keyword arguments passed along to the :obj:`__call__` method.
        """
        raise NotImplementedError(
            "Score sequence not supported for sentence transformers"
        )
@@ -1,130 +0,0 @@
"""Client class."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

from pyChatGPT import ChatGPT

from manifest.clients.client import Client
from manifest.request import LMRequest, Request

logger = logging.getLogger(__name__)


class ChatGPTClient(Client):
    """ChatGPT Client class."""

    # No params for ChatGPT
    PARAMS: Dict[str, Tuple[str, Any]] = {}
    REQUEST_CLS = LMRequest

    def connect(
        self, connection_str: Optional[str], client_args: Dict[str, Any]
    ) -> None:
        """
        Connect to ChatGPT.

        We use https://github.com/terry3041/pyChatGPT.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
        if self.session_key is None:
            raise ValueError(
                "ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
                "variable or pass through `client_connection`. "
                "For details, see https://github.com/terry3041/pyChatGPT "
                "and go through instructions for getting a session key."
            )
        self.host = None
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        self._chat_session = ChatGPT(self.session_key, verbose=False)

    def close(self) -> None:
        """Close the client."""
        self._chat_session = None

    def clear_conversations(self) -> None:
        """Clear conversations.

        Only works for ChatGPT.
        """
        self._chat_session.clear_conversations()

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return ""

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header.
        """
        return {}

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return False

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": "chatgpt", "engine": "chatgpt"}

    def format_response(self, response: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response

        Return:
            response as dict
        """
        return {
            "model": "chatgpt",
            "choices": [
                {
                    "text": response["message"],
                }
            ],
        }

    def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            request: request.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        if isinstance(request.prompt, list):
            raise ValueError("ChatGPT does not support batch inference.")

        prompt = str(request.prompt)
        request_params = request.to_dict(self.PARAMS)

        def _run_completion() -> Dict:
            try:
                res = self._chat_session.send_message(prompt)
            except Exception as e:
                logger.error(f"ChatGPT error {e}.")
                raise e
            return self.format_response(res)

        return _run_completion, request_params
@@ -0,0 +1,89 @@
"""Hugging Face client."""
import logging
from typing import Any, Dict, Optional, Tuple

import numpy as np
import requests

from manifest.clients.client import Client
from manifest.request import EmbeddingRequest

logger = logging.getLogger(__name__)


class HuggingFaceEmbeddingClient(Client):
    """HuggingFaceEmbedding client."""

    # User param -> (client param, default value)
    PARAMS: Dict[str, Tuple[str, Any]] = {}
    REQUEST_CLS = EmbeddingRequest
    NAME = "huggingfaceembedding"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the HuggingFace url.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        if not connection_str:
            raise ValueError("Must provide connection string")
        self.host = connection_str.rstrip("/")
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))

    def close(self) -> None:
        """Close the client."""
        pass

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/embed"

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header.
        """
        return {}

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        res = requests.post(self.host + "/params").json()
        res["client_name"] = self.NAME
        return res

    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response
            request: request

        Return:
            response as dict
        """
        # Convert array to np.array
        for choice in response["choices"]:
            choice["array"] = np.array(choice["array"])
        return response
@@ -0,0 +1,157 @@
"""OpenAIChat client."""
import copy
import logging
import os
from typing import Any, Dict, Optional

from manifest.clients.openai import OpenAIClient
from manifest.request import LMRequest

logger = logging.getLogger(__name__)

# List from https://platform.openai.com/docs/models/model-endpoint-compatibility
OPENAICHAT_ENGINES = {"gpt-3.5-turbo", "gpt-4", "gpt-4-32k"}


class OpenAIChatClient(OpenAIClient):
    """OpenAI Chat client."""

    # User param -> (client param, default value)
    PARAMS = {
        "engine": ("model", "gpt-3.5-turbo"),
        "temperature": ("temperature", 1.0),
        "max_tokens": ("max_tokens", 10),
        "n": ("n", 1),
        "top_p": ("top_p", 1.0),
        "stop_sequences": ("stop", None),  # OpenAI doesn't like empty lists
        "presence_penalty": ("presence_penalty", 0.0),
        "frequency_penalty": ("frequency_penalty", 0.0),
    }
    REQUEST_CLS = LMRequest
    NAME = "openaichat"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the OpenAI server.

        connection_str is passed as default OPENAI_API_KEY if variable not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "OpenAI API key not set. Set OPENAI_API_KEY environment "
                "variable or pass through `client_connection`."
            )
        self.host = "https://api.openai.com/v1"
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        if getattr(self, "engine") not in OPENAICHAT_ENGINES:
            raise ValueError(
                f"Invalid engine {getattr(self, 'engine')}. "
                f"Must be {OPENAICHAT_ENGINES}."
            )

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/chat/completions"

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return False

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": self.NAME, "engine": getattr(self, "engine")}

    def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
        """Format request params for chat.

        Args:
            request_params: request params.

        Returns:
            formatted request params.
        """
        # Format for chat model
        request_params = copy.deepcopy(request_params)
        prompt = request_params.pop("prompt")
        if isinstance(prompt, str):
            prompt_list = [prompt]
        else:
            prompt_list = prompt
        messages = [{"role": "user", "content": prompt} for prompt in prompt_list]
        request_params["messages"] = messages
        return request_params

    def _format_request_from_chat(self, response_dict: Dict[str, Any]) -> Dict:
        """Format response for standard response from chat.

        Args:
            response_dict: response.

        Return:
            formatted response.
        """
        new_choices = []
        response_dict = copy.deepcopy(response_dict)
        for message in response_dict["choices"]:
            new_choices.append({"text": message["message"]["content"]})
        response_dict["choices"] = new_choices
        return response_dict

    def _run_completion(
        self, request_params: Dict[str, Any], retry_timeout: int
    ) -> Dict:
        """Execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.

        Returns:
            response as dict.
        """
        # Format for chat model
        request_params = self._format_request_for_chat(request_params)
        response_dict = super()._run_completion(request_params, retry_timeout)
        # Reformat for text model
        response_dict = self._format_request_from_chat(response_dict)
        return response_dict

    async def _arun_completion(
        self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
    ) -> Dict:
        """Async execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.
            batch_size: batch size for requests.

        Returns:
            response as dict.
        """
        # Format for chat model
        request_params = self._format_request_for_chat(request_params)
        response_dict = await super()._arun_completion(
            request_params, retry_timeout, batch_size
        )
        # Reformat for text model
        response_dict = self._format_request_from_chat(response_dict)
        return response_dict
@@ -0,0 +1,204 @@
"""OpenAI client."""
import copy
import logging
import os
from typing import Any, Dict, List, Optional

import numpy as np
import tiktoken

from manifest.clients.openai import OpenAIClient
from manifest.request import EmbeddingRequest

logger = logging.getLogger(__name__)

OPENAI_EMBEDDING_ENGINES = {
    "text-embedding-ada-002",
}


class OpenAIEmbeddingClient(OpenAIClient):
    """OpenAI client."""

    # User param -> (client param, default value)
    PARAMS = {
        "engine": ("model", "text-embedding-ada-002"),
    }
    REQUEST_CLS = EmbeddingRequest
    NAME = "openaiembedding"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the OpenAI server.

        connection_str is passed as default OPENAI_API_KEY if variable not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "OpenAI API key not set. Set OPENAI_API_KEY environment "
                "variable or pass through `client_connection`."
            )
        self.host = "https://api.openai.com/v1"
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        if getattr(self, "engine") not in OPENAI_EMBEDDING_ENGINES:
            raise ValueError(
                f"Invalid engine {getattr(self, 'engine')}. "
                f"Must be {OPENAI_EMBEDDING_ENGINES}."
            )

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/embeddings"

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": self.NAME, "engine": getattr(self, "engine")}

    def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response
            request: request

        Return:
            response as dict
        """
        if "data" not in response:
            raise ValueError(f"Invalid response: {response}")
        if "usage" in response:
            # Handle splitting the usages for batch requests
            if len(response["data"]) == 1:
                if isinstance(response["usage"], list):
                    response["usage"] = response["usage"][0]
                response["usage"] = [response["usage"]]
            else:
                # Try to split usage
                split_usage = self.split_usage(request, response["data"])
                if split_usage:
                    response["usage"] = split_usage
        return response

    def _format_request_for_embedding(self, request_params: Dict[str, Any]) -> Dict:
        """Format request params for embedding.

        Args:
            request_params: request params.

        Returns:
            formatted request params.
        """
        # Format for embedding model
        request_params = copy.deepcopy(request_params)
        prompt = request_params.pop("prompt")
        if isinstance(prompt, str):
            prompt_list = [prompt]
        else:
            prompt_list = prompt
        request_params["input"] = prompt_list
        return request_params

    def _format_request_from_embedding(self, response_dict: Dict[str, Any]) -> Dict:
        """Format response from embedding for standard response.

        Args:
            response_dict: response.

        Return:
            formatted response.
        """
        new_choices = []
        response_dict = copy.deepcopy(response_dict)
        for res in response_dict.pop("data"):
            new_choices.append({"array": np.array(res["embedding"])})
        response_dict["choices"] = new_choices
        return response_dict

    def _run_completion(
        self, request_params: Dict[str, Any], retry_timeout: int
    ) -> Dict:
        """Execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.

        Returns:
            response as dict.
        """
        # Format for embedding model
        request_params = self._format_request_for_embedding(request_params)
        response_dict = super()._run_completion(request_params, retry_timeout)
        # Reformat for text model
        response_dict = self._format_request_from_embedding(response_dict)
        return response_dict

    async def _arun_completion(
        self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
    ) -> Dict:
        """Async execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.
            batch_size: batch size for requests.

        Returns:
            response as dict.
        """
        # Format for embedding model
        request_params = self._format_request_for_embedding(request_params)
        response_dict = await super()._arun_completion(
            request_params, retry_timeout, batch_size
        )
        # Reformat for text model
        response_dict = self._format_request_from_embedding(response_dict)
        return response_dict

    def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
        """Split usage into list of usages for each prompt."""
        try:
            encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
        except Exception:
            return []
        prompt = request["input"]
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt
        assert len(prompts) == len(choices)
        usages = []
        for pmt in prompts:
            pmt_tokens = len(encoding.encode(pmt))
            # No completion tokens for embedding models
            chc_tokens = 0
            usage = {
                "prompt_tokens": pmt_tokens,
                "completion_tokens": chc_tokens,
                "total_tokens": pmt_tokens + chc_tokens,
            }
            usages.append(usage)
        return usages
@@ -1,156 +0,0 @@
"""User query session logging."""
import logging
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from manifest.caches.serializers import Serializer

logging.getLogger("sqlitedict").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


class Session:
    """A user session for caching requests."""

    def __init__(self, session_id: Optional[str] = None) -> None:
        """
        Initialize session.

        If session_id already exists, will append to existing session.

        Args:
            session_id: session id.

        """
        manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
        self.db_file = manifest_home / ".manifest" / "session.db"
        self.db_file.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(str(self.db_file))
        self.serializer = Serializer()
        self._create_table()
        if not session_id:
            self.session_id = str(uuid.uuid4())
            self.query_id = 0
        else:
            self.session_id = session_id
            self.query_id = self._get_latest_query_id(self.session_id)
            self.query_id += 1
        logger.info(f"Starting session {self.session_id}")
        return

    def close(self) -> None:
        """Close the client."""
        self.conn.close()

    @classmethod
    def get_session_keys(cls, db_file: Path) -> List[str]:
        """Get available session keys from cached file."""
        try:
            conn = sqlite3.connect(str(db_file))
            query = """SELECT DISTINCT session_id FROM queries"""
            cur = conn.cursor()
            res = cur.execute(query)
            return [x[0] for x in res.fetchall()]
        except sqlite3.OperationalError:
            logger.info(
                "There is no database with the 'queries' table. "
                "Are you sure you are using the right session file?"
            )
            return []

    def _execute_query(self, query: str, *args: Any) -> Any:
        """
        Execute query with optional args.

        Args:
            query: query to execute.
        """
        cur = self.conn.cursor()
        res = cur.execute(query, args)
        self.conn.commit()
        return res

    def _create_table(self) -> None:
        """Create table if not exists."""
        query = """CREATE TABLE IF NOT EXISTS queries (
            query_id integer NOT NULL,
            session_id text NOT NULL,
            query_key text NOT NULL,
            response_key text NOT NULL
        );"""
        self._execute_query(query)
        return

    def _get_latest_query_id(self, session_id: str) -> int:
        """
        Get latest query id issued if resuming session.

        If no session_id, return -1.

        Args:
            session_id: session id.

        Returns:
            latest query id.
        """
        query = """SELECT query_id
            FROM queries
            WHERE session_id = ?
            ORDER BY query_id DESC LIMIT 1"""
        res = self._execute_query(query, session_id).fetchone()
        if res:
            return res[0]
        return -1

    def log_query(
        self, query_key: Dict[str, Any], response_key: Dict[str, Any]
    ) -> None:
        """
        Log the query and response.

        Args:
            query_key: query of user (dump of request params).
            response_key: response of server (dump of response).
        """
        query = """INSERT INTO queries VALUES (?, ?, ?, ?);"""
        self._execute_query(
            query,
            self.query_id,
            self.session_id,
            self.serializer.request_to_key(query_key),
            self.serializer.response_to_key(response_key),
        )
        self.query_id += 1
        return

    def get_last_queries(
        self, last_n: int = -1
    ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """
        Get last n queries from current session.

        If last_n is -1, return all queries.

        Args:
            last_n: last n queries.

        Returns:
            last n list of queries and outputs.
        """
        first_query = self.query_id - last_n if last_n > 0 else -1
        query = """SELECT query_key, response_key
            FROM queries
            WHERE session_id = ? AND query_id >= ?
            ORDER BY query_id;"""
        res = self._execute_query(query, self.session_id, first_query)
        parsed_res = [
            (
                self.serializer.key_to_request(pair[0]),
                self.serializer.key_to_response(pair[1]),
            )
            for pair in res.fetchall()
        ]
        return parsed_res
@@ -1 +1 @@
__version__ = "0.1.2"
__version__ = "0.1.3"
File diff suppressed because it is too large
@@ -1,81 +0,0 @@
"""Test session."""
import sqlite3
from pathlib import Path

import pytest

from manifest.session import Session


@pytest.mark.usefixtures("session_cache")
def test_init(session_cache: Path) -> None:
    """Test session initialization."""
    session = Session()
    assert isinstance(session.conn, sqlite3.Connection)
    assert session.db_file == session_cache / ".manifest" / "session.db"
    assert session.query_id == 0
    assert (session_cache / ".manifest" / "session.db").exists()
    # Remove session cache file.
    (session_cache / ".manifest" / "session.db").unlink()

    session = Session("dog_days")
    assert isinstance(session.conn, sqlite3.Connection)
    assert session.db_file == session_cache / ".manifest" / "session.db"
    assert session.query_id == 0
    assert session.session_id == "dog_days"
    assert (session_cache / ".manifest" / "session.db").exists()
    session.close()


@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache: Path) -> None:
    """Test session log_query."""
    session = Session()
    assert session.get_last_queries(1) == []

    query_key = {"query": "What is your name?", "time": "now"}
    response_key = {"response": "I don't have a name", "engine": "nodel"}
    session.log_query(query_key, response_key)
    assert session.query_id == 1
    assert session.get_last_queries(1) == [(query_key, response_key)]

    query_key2 = {"query2": "What is your name?", "time": "now"}
    response_key2 = {"response2": "I don't have a name", "engine": "nodel"}
    session.log_query(query_key2, response_key2)
    assert session.query_id == 2
    assert len(session.get_last_queries(1)) == 1
    assert session.get_last_queries(2) == [
        (query_key, response_key),
        (query_key2, response_key2),
    ]
    session.close()


@pytest.mark.usefixtures("session_cache")
def test_resume_query(session_cache: Path) -> None:
    """Test resuming a session."""
    session = Session(session_id="dog_days")
    query_key = {"query": "What is your name?", "time": "now"}
    response_key = {"response": "I don't have a name", "engine": "nodel"}
    session.log_query(query_key, response_key)
    session.close()

    session = Session(session_id="dog_days")
    assert session.query_id == 1


@pytest.mark.usefixtures("session_cache")
def test_session_keys(session_cache: Path) -> None:
    """Test get session keys."""
    # Assert empty before queries
    assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == []
    # Add queries and make sure session is logged
    session = Session(session_id="dog_days")
    query_key = {"query": "What is your name?", "time": "now"}
    response_key = {"response": "I don't have a name", "engine": "nodel"}
    session.log_query(query_key, response_key)
    session.close()

    assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == [
        "dog_days"
    ]