Merge remote-tracking branch 'upstream/main'

pull/79/head
Michael Wornow 1 year ago
commit 3084200233

@ -1,6 +1,18 @@
0.1.1 - Unreleased
---------------------
Added
^^^^^
* Async support in arun_batch
Fixed
^^^^^
* Batched runs now cache individual items
* Score prompt no longer truncates the outside token
Removed
^^^^^^^
* Deprecated chatGPT in favor of openaichat, which uses the OpenAI chat completions API
* Deprecated Sessions
0.1.0 - 2022-01-31
---------------------

@ -7,6 +7,7 @@ How to make prompt programming with Foundation Models a little easier.
- [Getting Started](#getting-started)
- [Manifest](#manifest-components)
- [Local HuggingFace Models](#local-huggingface-models)
- [Embedding Models](#embedding-models)
- [Development](#development)
- [Cite](#cite)
@ -22,12 +23,6 @@ Install with diffusion support:
pip install manifest-ml[diffusers]
```
Install with ChatGPT support:
```bash
pip install manifest-ml[chatgpt]
```
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
Install with HuggingFace local model support:
```bash
pip install manifest-ml[api]
@ -53,6 +48,9 @@ manifest = Manifest(
manifest.run("Why is the grass green?")
```
## Examples
We have example notebooks and Python scripts located in [examples](examples). These show how to use different models, model types (i.e., text, diffusers, or embedding models), and how to run queries asynchronously.
# Manifest Components
Manifest is meant to be a very lightweight package to help with prompt design and iteration. Three key design decisions of Manifest are
@ -106,24 +104,6 @@ manifest = Manifest(
```
As a hint, if you want to get Redis running, see the `docker run` command below under development.
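If you don't already have a Redis instance, one quick way to start one locally is with Docker (a sketch; the image tag and port mapping here are assumptions, adjust to your setup):
```bash
docker run -d -p 6379:6379 redis
```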
## Sessions
Each Manifest run supports a session that, in addition to a global cache, connects to a local SQLite DB to store user query history.
```python
manifest = Manifest(
client_name = "openai",
cache_name = "sqlite",
cache_connection = "mycache.sqlite",
session_id = "grass_color",
)
```
will start a Manifest session with the session name `grass_color`. This can be helpful for a user to logically keep track of sessions, see interaction history, and resume sessions if desired. If the session id provided is `_default`, we generate a random id for the user.
After a few queries, the user can explore their history
```python
manifest.get_last_queries(4)
```
will retrieve the last 4 model queries and responses.
## Running Queries
Once you have a Manifest client, you can write and develop prompts.
@ -136,6 +116,12 @@ You can also run over multiple examples if supported by the client.
results = manifest.run(["Where are the cats?", "Where are the dogs?"])
```
We support async queries as well via
```python
import asyncio
results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the dogs?"]))
```
If something doesn't go right, you can also ask to get a raw manifest Response.
```python
result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
@ -202,6 +188,23 @@ python3 -m manifest.api.app \
--percent_max_gpu_mem_reduction 0.85
```
# Embedding Models
Manifest also supports getting embeddings from local models and hosted APIs. This is all controlled through the `client_name` argument, and you still use `run` and `arun_batch`.
To use OpenAI's embedding models, simply run
```python
manifest = Manifest(client_name="openaiembedding")
embedding_as_np = manifest.run("Get me an embedding for a bunny")
```
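The embedding comes back as a numpy array; the embedding notebook added in this PR shows `text-embedding-ada-002` returning a 1536-dimensional vector:
```python
print(embedding_as_np.shape)  # (1536,) for text-embedding-ada-002
```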
As explained above, you can also load local HuggingFace models that give you embeddings. If you want to use a standard generative model, load the model as above and use `client_name="huggingfaceembedding"`. If you want to use a standard embedding model, like those from SentenceTransformers, load your local model via
```bash
python3 -m manifest.api.app \
--model_type sentence_transformers \
--model_name_or_path all-mpnet-base-v2 \
--device 0
```
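Once the server is running, point a `huggingfaceembedding` client at it. A minimal sketch (assuming the server is listening on the default `FLASK_PORT` of 5000; the embedding notebook added in this PR follows the same pattern):
```python
from manifest import Manifest

# Connect to the locally hosted embedding server started above
manifest = Manifest(
    client_name="huggingfaceembedding",
    client_connection="http://127.0.0.1:5000",
)

emb = manifest.run("Is this an embedding?")
print(emb.shape)  # e.g. (768,) for all-mpnet-base-v2
```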
# Development
Before submitting a PR, run
```bash

@ -0,0 +1,27 @@
import asyncio
import time
from manifest import Manifest
def main():
manifest = Manifest(
client_name="openaichat",
)
print("Running in serial")
prompts = [f"Tell me something interesting about {i}" for i in range(50)]
st = time.time()
for pmt in prompts:
_ = manifest.run(pmt)
print(f"For loop: {time.time() - st :.2f}")
print("Running with async")
st = time.time()
_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30))
print(f"Async loop: {time.time() - st :.2f}")
if __name__ == "__main__":
main()

@ -1,63 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"import os\n",
"\n",
"# ChatGPT tries hard not to give people programmatic access.\n",
"# As a warning, this will open a browser window.\n",
"# You need to install xvfb and chromium for linux headless mode to work\n",
"# See https://github.com/terry3041/pyChatGPT\n",
"\n",
"# The responses are not fast\n",
"manifest = Manifest(\n",
" client_name=\"chatgpt\",\n",
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
")\n",
"print(manifest.run(\"Describe in a single, short sentence what is the best sandwhich in the world. Be short and concise.\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mlcore",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "1ea9cc00d433352044b557b1784ac6e58df03de4b7bb312554014351989eb135"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use OpenAI\n",
"\n",
"Set you `OPENAI_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
]
}
],
"source": [
"from manifest import Manifest\n",
"\n",
"manifest = Manifest(client_name=\"openaiembedding\")\n",
"print(manifest.client.get_model_params())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1536,)\n"
]
}
],
"source": [
"emb = manifest.run(\"Is this an embedding?\")\n",
"print(emb.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using Locally Hosted Huggingface LM\n",
"\n",
"Run\n",
"```\n",
"python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
"```\n",
"or\n",
"```\n",
"python3 manifest/api/app.py --model_type sentence_transformers --model_name_or_path all-mpnet-base-v2 --device 0\n",
"```\n",
"\n",
"in a separate `screen` or `tmux`. Make sure to note the port. You can change this with `export FLASK_PORT=<port>`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_name': 'all-mpnet-base-v2', 'model_path': 'all-mpnet-base-v2', 'client_name': 'huggingfaceembedding'}\n"
]
}
],
"source": [
"from manifest import Manifest\n",
"\n",
"# Local hosted GPT Neo 125M\n",
"manifest = Manifest(\n",
" client_name=\"huggingfaceembedding\",\n",
" client_connection=\"http://127.0.0.1:6000\",\n",
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client.get_model_params())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(768,)\n",
"(768,) (768,)\n"
]
}
],
"source": [
"emb = manifest.run(\"Is this an embedding?\")\n",
"print(emb.shape)\n",
"\n",
"emb = manifest.run([\"Is this an embedding?\", \"Bananas!!!\"])\n",
"print(emb[0].shape, emb[1].shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -2,6 +2,5 @@
from manifest.manifest import Manifest
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session
__all__ = ["Manifest", "Response", "Session"]
__all__ = ["Manifest", "Response", "Request"]

@ -15,6 +15,7 @@ from manifest.api.models.huggingface import (
CrossModalEncoderModel,
TextGenerationModel,
)
from manifest.api.models.sentence_transformer import SentenceTransformerModel
from manifest.api.response import ModelResponse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -27,6 +28,7 @@ model_type = None
PORT = int(os.environ.get("FLASK_PORT", 5000))
MODEL_CONSTRUCTORS = {
"huggingface": TextGenerationModel,
"sentence_transformers": SentenceTransformerModel,
"huggingface_crossmodal": CrossModalEncoderModel,
"diffuser": DiffuserModel,
}
@ -184,13 +186,13 @@ def completions() -> Response:
if model_type == "diffuser":
# Assign None logprob as it's not supported in diffusers
results = [
{"array": r[0], "logprob": None, "token_logprobs": None}
{"array": r[0], "logprob": None, "tokens": None, "token_logprobs": None}
for r in result_gens
]
res_type = "image_generation"
else:
results = [
{"text": r[0], "logprob": r[1], "token_logprobs": r[2]}
{"text": r[0], "logprob": r[1], "tokens": r[2], "token_logprobs": r[3]}
for r in result_gens
]
res_type = "text_completion"
@ -210,11 +212,14 @@ def completions() -> Response:
@app.route("/embed", methods=["POST"])
def embed() -> Dict:
def embed() -> Response:
"""Get embed for generation."""
modality = request.json["modality"]
if "modality" in request.json:
modality = request.json["modality"]
else:
modality = "text"
if modality == "text":
prompts = request.json["prompts"]
prompts = request.json["prompt"]
elif modality == "image":
import base64
@ -222,19 +227,36 @@ def embed() -> Dict:
prompts = [
Image.open(io.BytesIO(base64.b64decode(data)))
for data in request.json["prompts"]
for data in request.json["prompt"]
]
else:
raise ValueError("modality must be text or image")
results = []
embeddings = model.embed(prompts)
for embedding in embeddings:
results.append(embedding.tolist())
try:
results = []
embeddings = model.embed(prompts)
for embedding in embeddings:
results.append(
{
"array": embedding,
"logprob": None,
"tokens": None,
"token_logprobs": None,
}
)
# transform the result into the openai format
# return Response(results, response_type="text_completion").__dict__()
return {"result": results}
return Response(
json.dumps(
ModelResponse(results, response_type="embedding_generation").__dict__()
),
status=200,
)
except Exception as e:
logger.error(e)
return Response(
json.dumps({"message": str(e)}),
status=400,
)
@app.route("/score_sequence", methods=["POST"])
@ -253,7 +275,8 @@ def score_sequence() -> Response:
{
"text": prompt if isinstance(prompt, str) else prompt[i],
"logprob": r[0],
"token_logprobs": r[1],
"tokens": r[1],
"token_logprobs": r[2],
}
for i, r in enumerate(score_list)
]

@ -1,7 +1,8 @@
"""Huggingface model."""
"""Diffuser model."""
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
@ -74,7 +75,7 @@ class DiffuserModel(Model):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -91,12 +92,25 @@ class DiffuserModel(Model):
prompt = [prompt]
result = self.pipeline(prompt, output_type="np.array", **kwargs)
# Return None for logprobs and token logprobs
return [(im, None, None) for im in result["images"]]
return [(im, None, None, None) for im in result["images"]]
@torch.no_grad()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Embed the prompt from model.
Args:
prompt: prompt to embed from.
Returns:
list of embeddings (list of length 1 for 1 embedding).
"""
raise NotImplementedError("Embed not supported for diffusers")
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.

@ -191,6 +191,7 @@ class GenerationPipeline:
"logprobs": logits[
range(num_generated_tokens), i, output_seq[-num_generated_tokens:]
].tolist(),
"tokens": output_seq[-num_generated_tokens:].tolist(),
}
for i, output_seq in enumerate(output_dict.sequences)
]
@ -547,20 +548,37 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.
Embed the prompt from model.
Args:
prompt: prompt to generate from.
prompt: prompt to embed from.
Returns:
embedding
list of embeddings (list of length 1 for 1 embedding).
"""
pass
if isinstance(prompt, str):
prompt = [prompt]
encoded_prompt = self.pipeline.tokenizer(
prompt,
max_length=self.pipeline.max_length,
truncation=True,
padding=True,
return_tensors="pt",
)
encoded_prompt = encoded_prompt.to(self.pipeline.device)
# Get last hidden state
output = self.pipeline.model( # type: ignore
**encoded_prompt,
output_hidden_states=True,
return_dict=True,
)
last_hidden_state = output["hidden_states"][-1][:, -1, :]
return last_hidden_state.cpu().numpy()
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -589,6 +607,7 @@ class TextGenerationModel(HuggingFaceModel):
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[int], r["tokens"]),
cast(List[float], r["logprobs"]),
)
for r in result
@ -598,7 +617,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.
Args:
@ -622,22 +641,21 @@ class TextGenerationModel(HuggingFaceModel):
**encoded_prompt,
).logits
# For causal decoders, shift logts and labels
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)[
..., 1:, :
]
masked_log_probs = (
labels_attention_mask.float()
* torch.log_softmax(logits.float(), dim=-1)[..., :-1, :]
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
logits.float(), dim=-1
)
seq_token_log_probs = torch.gather(
masked_log_probs, -1, encoded_prompt["labels"][..., 1:].unsqueeze(-1)
masked_log_probs, -1, encoded_prompt["labels"].unsqueeze(-1)
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_log_prob = seq_token_log_probs.sum(dim=-1)
return [
(seq, seq_token)
for seq, seq_token in zip(
seq_log_prob.tolist(), seq_token_log_probs.tolist()
(seq, tokens, seq_token)
for seq, tokens, seq_token in zip(
seq_log_prob.tolist(),
encoded_prompt["input_ids"].tolist(),
seq_token_log_probs.tolist(),
)
]

@ -1,14 +1,12 @@
"""Model class."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union
import numpy as np
class Model(ABC):
class Model:
"""Model class."""
@abstractmethod
def __init__(
self,
model_name_or_path: str,
@ -41,14 +39,13 @@ class Model(ABC):
"""
raise NotImplementedError()
@abstractmethod
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
raise NotImplementedError()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -59,26 +56,26 @@ class Model(ABC):
Returns:
list of generated text (list of length 1 for 1 generation).
Each item is the response, answer logprob,
Each item is the response, answer logprob, list of tokens,
and list of logprobs for each token.
"""
raise NotImplementedError()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.
Embed the prompt from model.
Args:
prompt: prompt to generate from.
prompt: prompt to embed from.
Returns:
embedding
list of embeddings (list of length 1 for 1 embedding).
"""
raise NotImplementedError()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.
@ -89,6 +86,6 @@ class Model(ABC):
Additional keyword arguments passed along to the :obj:`__call__` method.
Returns:
Tuple of scores for each choice and logprobs for the tokens of each choice.
Tuple of total score, tokens, and probs per token.
"""
raise NotImplementedError()

@ -0,0 +1,113 @@
"""Sentence transformer model."""
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from manifest.api.models.model import Model
class SentenceTransformerModel(Model):
"""SentenceTransformer model."""
def __init__(
self,
model_name_or_path: str,
model_type: Optional[str] = None,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
use_deepspeed: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = False,
):
"""
Initialize model.
All arguments will be passed in the request from Manifest.
Args:
model_name_or_path: model name string.
model_config: model config string.
cache_dir: cache directory for model.
device: device to use for model.
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
raise ValueError(
"Cannot use accelerate or parallelize or "
"bitsandbytes or deepspeeed with sentence transformers"
)
# Check if providing path
self.model_name = model_name_or_path
print("Model Name:", self.model_name)
torch_device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
self.embedding_model = SentenceTransformer(self.model_name, device=torch_device)
self.embedding_model.to(torch_device)
self.embedding_model.eval()
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
return {"model_name": self.model_name, "model_path": self.model_name}
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
Outputs must be generated text and score, not including prompt.
Args:
prompt: prompt to generate from.
Returns:
list of generated text (list of length 1 for 1 generation).
"""
raise NotImplementedError("Generate not supported for sentence transformers")
@torch.no_grad()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Embed the prompt from model.
Args:
prompt: prompt to embed from.
Returns:
list of embeddings (list of length 1 for 1 embedding).
"""
if isinstance(prompt, str):
prompt = [prompt]
return self.embedding_model.encode(prompt)
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.
Args:
prompt (:obj:`str` or :obj:`List[str]`):
The prompt to score the choices against.
**kwargs:
Additional keyword arguments passed along to the :obj:`__call__` method.
"""
raise NotImplementedError(
"Score sequence not supported for sentence transformers"
)

@ -16,17 +16,24 @@ class ModelResponse:
"text_completion",
"prompt_logit_score",
"image_generation",
"embedding_generation",
}:
raise ValueError(
f"Invalid response type: {self.response_type}. "
"Must be one of: text_completion, prompt_logit_score, image_generation."
"Must be one of: text_completion, prompt_logit_score, "
"image_generation, embedding_generation."
)
self.response_id = str(uuid.uuid4())
self.created = int(time.time())
def __dict__(self) -> Dict[str, Any]: # type: ignore
"""Return dictionary representation of response."""
key = "text" if self.response_type != "image_generation" else "array"
key = (
"text"
if self.response_type
not in {"prompt_logit_score", "image_generation", "embedding_generation"}
else "array"
)
return {
"id": self.response_id,
"object": self.response_type,
@ -36,6 +43,7 @@ class ModelResponse:
{
key: result[key],
"logprob": result["logprob"],
"tokens": result["tokens"],
"token_logprobs": result["token_logprobs"],
}
if key == "text"

@ -1,26 +1,16 @@
"""Cache for queries and responses."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
from manifest.caches.serializers import ArraySerializer, Serializer
from manifest.response import Response
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "array",
},
}
from typing import Any, Dict, Union
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
from manifest.response import RESPONSE_CONSTRUCTORS, Response
CACHE_CONSTRUCTOR = {
"diffuser": ArraySerializer,
"tomadiffuser": ArraySerializer,
# Non-text return type caches
ARRAY_CACHE_TYPES = {
"diffuser",
"tomadiffuser",
"openaiembedding",
"huggingfaceembedding",
}
@ -34,17 +24,21 @@ class Cache(ABC):
cache_args: Dict[str, Any] = {},
):
"""
Initialize client.
Initialize cache.
Args:
connection_str: connection string.
client_name: name of client.
cache_args: arguments for cache.
cache_args are passed to client as default parameters.
cache_args are any arguments needed to initialize the cache.
For clients like OpenAI that do not require a connection,
the connection_str can be None.
Further, cache_args can contain `array_serializer` as a string
for embedding or image return types (e.g. diffusers) with values
as `local_file` or `byte_string`. `local_file` will save the
array in a local file and cache a pointer to the file.
`byte_string` will convert the array to a byte string and cache
the entire byte string. `byte_string` is the default.
Args:
connection_str: connection string for client.
@ -52,7 +46,22 @@ class Cache(ABC):
"""
self.client_name = client_name
self.connect(connection_str, cache_args)
self.serializer = CACHE_CONSTRUCTOR.get(client_name, Serializer)()
if self.client_name in ARRAY_CACHE_TYPES:
array_serializer = cache_args.pop("array_serializer", "byte_string")
if array_serializer not in ["local_file", "byte_string"]:
raise ValueError(
"array_serializer must be local_file or byte_string,"
f" not {array_serializer}"
)
self.serializer = (
ArraySerializer()
if array_serializer == "local_file"
else NumpyByteSerializer()
)
else:
# If user has array_serializer type, it will throw an error as
# it is not recognized for non-array return types.
self.serializer = Serializer()
@abstractmethod
def close(self) -> None:
@ -101,20 +110,35 @@ class Cache(ABC):
"""Commit any results."""
raise NotImplementedError()
def get(
self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict]
) -> Response:
"""Get the result of request (by calling compute as needed)."""
def get(self, request: Dict) -> Union[Response, None]:
"""Get the result of request (by calling compute as needed).
Args:
request: request to get.
response: response to get.
Returns:
Response object or None if not in cache.
"""
key = self.serializer.request_to_key(request)
cached_response = self.get_key(key)
if cached_response and not overwrite_cache:
if cached_response:
cached = True
response = self.serializer.key_to_response(cached_response)
else:
# Type Response
response = compute()
self.set_key(key, self.serializer.response_to_key(response))
cached = False
return Response(
response, cached, request, **RESPONSE_CONSTRUCTORS.get(self.client_name, {})
)
return Response(
response,
cached,
request,
**RESPONSE_CONSTRUCTORS.get(self.client_name, {}),
)
return None
def set(self, request: Dict, response: Dict) -> None:
"""Set the value for the key.
Args:
request: request to set.
response: response to set.
"""
key = self.serializer.request_to_key(request)
self.set_key(key, self.serializer.response_to_key(response))

@ -100,7 +100,7 @@ class PostgresCache(Cache):
table: table to get key in.
"""
request = (
self.session.query(Request)
self.session.query(Request) # type: ignore
.filter_by(key=self._hash_key(key, table))
.first()
)
@ -119,7 +119,7 @@ class PostgresCache(Cache):
table: table to set key in.
"""
key = self._hash_key(key, table)
request = self.session.query(Request).filter_by(key=key).first()
request = self.session.query(Request).filter_by(key=key).first() # type: ignore
if request:
request.response = value # type: ignore
else:

@ -1,10 +1,12 @@
"""Serializer."""
import io
import json
import os
from pathlib import Path
from typing import Dict
import numpy as np
import xxhash
from manifest.caches.array_cache import ArrayCache
@ -62,6 +64,64 @@ class Serializer:
return json.loads(key)
class NumpyByteSerializer(Serializer):
"""Serializer by casting array to byte string."""
def response_to_key(self, response: Dict) -> str:
"""
Normalize a response into a key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
# Assume response is a dict with keys "choices" -> List dicts
# with keys "array".
choices = response["choices"]
# We don't want to modify the response in place
# but we want to avoid calling deepcopy on an array
del response["choices"]
response_copy = response.copy()
response["choices"] = choices
response_copy["choices"] = []
for choice in choices:
if "array" not in choice:
raise ValueError(
f"Choice with keys {choice.keys()} does not have array key."
)
arr = choice["array"]
# Avoid copying an array
del choice["array"]
new_choice = choice.copy()
choice["array"] = arr
with io.BytesIO() as f:
np.savez_compressed(f, data=arr)
hash_str = f.getvalue().hex()
new_choice["array"] = hash_str
response_copy["choices"].append(new_choice)
return json.dumps(response_copy, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
Convert the normalized version to the response.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
response = json.loads(key)
for choice in response["choices"]:
hash_str = choice["array"]
byte_str = bytes.fromhex(hash_str)
with io.BytesIO(byte_str) as f:
choice["array"] = np.load(f)["data"]
return response
class ArraySerializer(Serializer):
"""Serializer for array."""

@ -27,9 +27,9 @@ class AI21Client(Client):
"n": ("numResults", 1),
"top_p": ("topP", 1.0),
"stop_sequences": ("stopSequences", []),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
NAME = "ai21"
def connect(
self,
@ -92,14 +92,15 @@ class AI21Client(Client):
Returns:
model params.
"""
return {"model_name": "ai21", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@ -1,130 +0,0 @@
"""Client class."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from pyChatGPT import ChatGPT
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
class ChatGPTClient(Client):
"""ChatGPT Client class."""
# No params for ChatGPT
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = LMRequest
def connect(
self, connection_str: Optional[str], client_args: Dict[str, Any]
) -> None:
"""
Connect to ChatGPT.
We use https://github.com/terry3041/pyChatGPT.
Arsg:
connection_str: connection string.
client_args: client arguments.
"""
self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
if self.session_key is None:
raise ValueError(
"ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
"variable or pass through `client_connection`. "
"For details, see https://github.com/terry3041/pyChatGPT "
"and go through instructions for getting a session key."
)
self.host = None
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
self._chat_session = ChatGPT(self.session_key, verbose=False)
def close(self) -> None:
"""Close the client."""
self._chat_session = None
def clear_conversations(self) -> None:
"""Clear conversations.
Only works for ChatGPT.
"""
self._chat_session.clear_conversations()
def get_generation_url(self) -> str:
"""Get generation URL."""
return ""
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "chatgpt", "engine": "chatgpt"}
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"model": "chatgpt",
"choices": [
{
"text": response["message"],
}
],
}
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
request: request.
Returns:
request function that takes no input.
request parameters as dict.
"""
if isinstance(request.prompt, list):
raise ValueError("ChatGPT does not support batch inference.")
prompt = str(request.prompt)
request_params = request.to_dict(self.PARAMS)
def _run_completion() -> Dict:
try:
res = self._chat_session.send_message(prompt)
except Exception as e:
logger.error(f"ChatGPT error {e}.")
raise e
return self.format_response(res)
return _run_completion, request_params

@ -1,21 +1,42 @@
"""Client class."""
import asyncio
import copy
import logging
import math
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import aiohttp
import requests
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
from manifest.request import Request
from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
logger = logging.getLogger(__name__)
def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
"""Return whether to retry if ratelimited."""
try:
if isinstance(retry_base.outcome.exception(), requests.exceptions.HTTPError):
exception = cast(
requests.exceptions.HTTPError, retry_base.outcome.exception()
)
if exception.response.status_code == 429: # type: ignore
return True
except Exception:
pass
return False
class Client(ABC):
"""Client class."""
# Must be overridden by child class
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = Request
NAME: str = None
def __init__(
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
@ -93,7 +114,7 @@ class Client(ABC):
"""
return list(self.PARAMS.keys())
def get_request_params(
def get_request(
self, prompt: Union[str, List[str]], request_args: Dict[str, Any]
) -> Request:
"""
@ -109,23 +130,149 @@ class Client(ABC):
params = {"prompt": prompt}
for key in self.PARAMS:
params[key] = request_args.pop(key, getattr(self, key))
return self.REQUEST_CLS(**params)
for key in DEFAULT_REQUEST_KEYS:
if key not in params and key in request_args:
params[key] = request_args.pop(key)
return self.REQUEST_CLS(**params) # type: ignore
def format_response(self, response: Dict) -> Dict[str, Any]:
def get_request_params(self, request: Request) -> Dict[str, Any]:
"""Get request params.
Add default keys that we need for requests such as batch_size.
We drop these before sending to the model.
"""
params_to_add = DEFAULT_REQUEST_KEYS.copy()
params_to_add.update(self.PARAMS)
request_params = request.to_dict(params_to_add)
return request_params
def get_cache_key(self, request: Request) -> Dict[str, Any]:
"""Get cache key for request.
Skip keys that are not cache keys such as batch_size.
"""
request_params = self.get_request_params(request)
for key in NOT_CACHE_KEYS:
request_params.pop(key, None)
request_params.update(self.get_model_params())
return request_params
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
return []
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "choices" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["choices"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["choices"])
if split_usage:
response["usage"] = split_usage
return response
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
def split_requests(
self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt"
) -> List[Dict[str, Any]]:
"""Split request into batch_sized request.
Args:
request_params: request params.
batch_size: batch size for requests.
key: key to batch over
Returns:
list of request params.
"""
data = copy.deepcopy(request_params[key])
data_size = len(request_params[key])
request_params_list = []
for i in range(0, data_size, batch_size):
params = copy.deepcopy(request_params)
params[key] = data[i] if batch_size == 1 else data[i : i + batch_size]
request_params_list.append(params)
return request_params_list
@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
)
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
post_str = self.get_generation_url()
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
try:
res.raise_for_status()
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json(), request_params)
@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
post_str = self.get_generation_url()
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
async with session.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json, request_params)
def run_request(self, request: Request) -> Response:
"""
Run the request and return a response.
@ -133,44 +280,85 @@ class Client(ABC):
request: request.
Returns:
request function that takes no input.
request parameters as dict.
response.
"""
if isinstance(request.prompt, list) and not self.supports_batch_inference():
raise ValueError(
f"{self.__class__.__name__} does not support batch inference."
)
request_params = request.to_dict(self.PARAMS)
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
return Response(
response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
)
def _run_completion() -> Dict:
post_str = self.get_generation_url()
try:
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error(
f"{self.__class__.__name__} request timed out."
" Increase client_timeout."
)
raise e
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json())
return _run_completion, request_params
async def arun_batch_request(self, request: Request) -> Response:
"""
Run the request in batches asynchronously and return a response.
Args:
request: request.
Returns:
response.
"""
required_batch_size = None
if not self.supports_batch_inference():
required_batch_size = 1
if not isinstance(request.prompt, list):
raise AssertionError(
"request.prompt must be a list for async batch inference."
)
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
batch_size = request_params.pop("batch_size")
batch_size = required_batch_size or batch_size
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
num_batches = len(request.prompt) // batch_size
if len(request.prompt) % batch_size != 0:
batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1)))
request_batches = self.split_requests(request_params, batch_size)
all_tasks = [
asyncio.create_task(self._arun_completion(batch, retry_timeout, batch_size))
for batch in request_batches
]
responses = await asyncio.gather(*all_tasks)
# Flatten responses
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
if usages:
final_response_dict["usage"] = usages
return Response(
final_response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
)
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.

@ -26,9 +26,9 @@ class CohereClient(Client):
"frequency_penalty": ("frequency_penalty", 0.0),
"presence_penalty": ("presence_penalty", 0.0),
"stop_sequences": ("stop_sequences", None),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
NAME = "cohere"
def connect(
self,
@ -91,14 +91,15 @@ class CohereClient(Client):
Returns:
model params.
"""
return {"model_name": "cohere", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@ -22,9 +22,9 @@ class DiffuserClient(Client):
"n": ("num_images_per_prompt", 1),
"guidance_scale": ("guidance_scale", 7.5),
"eta": ("eta", 0.0),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest
NAME = "diffuser"
def connect(
self,
@ -82,15 +82,17 @@ class DiffuserClient(Client):
Returns:
model params.
"""
res = requests.post(self.host + "/params")
return res.json()
res = requests.post(self.host + "/params").json()
res["client_name"] = self.NAME
return res
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@ -1,9 +1,10 @@
"""Dummy client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
from manifest.response import Response
logger = logging.getLogger(__name__)
@ -16,6 +17,7 @@ class DummyClient(Client):
"n": ("num_results", 1),
}
REQUEST_CLS = LMRequest
NAME = "dummy"
def connect(
self,
@ -67,7 +69,7 @@ class DummyClient(Client):
"""
return {"engine": "dummy"}
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
def run_request(self, request: Request) -> Response:
"""
Run the request and return a response.
@ -84,19 +86,32 @@ class DummyClient(Client):
num_results = 1
request_params = request.to_dict(self.PARAMS)
def _run_completion() -> Dict:
return {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results
}
response_dict = {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}]
* int(request_params["num_results"])
* num_results,
}
return Response(response_dict, False, request_params)
async def arun_batch_request(self, request: Request) -> Response:
"""
Run the request batch asynchronously and return a response.
Args:
request: request.
return _run_completion, request_params
Returns:
response.
"""
return self.run_request(request)
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -113,17 +128,15 @@ class DummyClient(Client):
num_results = 1
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
return {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
}
for i in range(num_results)
]
}
return _run_completion, request_params
response_dict = {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
}
for i in range(num_results)
]
}
return Response(response_dict, False, request_params)

@ -1,11 +1,12 @@
"""Hugging Face client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import requests
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request
from manifest.response import Response
logger = logging.getLogger(__name__)
@ -22,9 +23,9 @@ class HuggingFaceClient(Client):
"top_k": ("top_k", 50),
"repetition_penalty": ("repetition_penalty", 1.0),
"do_sample": ("do_sample", True),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
NAME = "huggingface"
def connect(
self,
@ -75,13 +76,14 @@ class HuggingFaceClient(Client):
Returns:
model params.
"""
res = requests.post(self.host + "/params")
return res.json()
res = requests.post(self.host + "/params").json()
res["client_name"] = self.NAME
return res
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -92,26 +94,26 @@ class HuggingFaceClient(Client):
request function that takes no input.
request parameters as dict.
"""
request_params = request.to_dict(self.PARAMS)
request_params = self.get_request_params(request)
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
# Do not add params like we do with request as the model isn't sampling
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
post_str = self.host + "/score_sequence"
try:
res = requests.post(
post_str,
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error("HF request timed out. Increase client_timeout.")
raise e
except requests.exceptions.HTTPError as e:
logger.error(res.text)
raise e
return res.json()
return _run_completion, request_params
post_str = self.host + "/score_sequence"
try:
res = requests.post(
post_str,
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error("HF request timed out. Increase client_timeout.")
raise e
except requests.exceptions.HTTPError as e:
logger.error(res.text)
raise e
response_dict = res.json()
return Response(response_dict, cached=False, request_params=request_params)

@ -0,0 +1,89 @@
"""Hugging Face client."""
import logging
from typing import Any, Dict, Optional, Tuple
import numpy as np
import requests
from manifest.clients.client import Client
from manifest.request import EmbeddingRequest
logger = logging.getLogger(__name__)
class HuggingFaceEmbeddingClient(Client):
"""HuggingFaceEmbedding client."""
# User param -> (client param, default value)
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = EmbeddingRequest
NAME = "huggingfaceembedding"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the HuggingFace url.
Args:
connection_str: connection string.
client_args: client arguments.
"""
if not connection_str:
raise ValueError("Must provide connection string")
self.host = connection_str.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/embed"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
res = requests.post(self.host + "/params").json()
res["client_name"] = self.NAME
return res
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict
"""
# Convert array to np.array
for choice in response["choices"]:
choice["array"] = np.array(choice["array"])
return response

@ -1,10 +1,12 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Type
import tiktoken
from manifest.clients.client import Client
from manifest.request import LMRequest
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
@ -38,9 +40,9 @@ class OpenAIClient(Client):
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
REQUEST_CLS: Type[Request] = LMRequest
NAME = "openai"
def connect(
self,
@ -101,4 +103,29 @@ class OpenAIClient(Client):
Returns:
model params.
"""
return {"model_name": "openai", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
try:
encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
except Exception:
return []
prompt = request["prompt"]
# If n > 1 and prompt is a string, we need to split it into a list
if isinstance(prompt, str):
prompts = [prompt] * len(choices)
else:
prompts = prompt
assert len(prompts) == len(choices)
usages = []
for pmt, chc in zip(prompts, choices):
pmt_tokens = len(encoding.encode(pmt))
chc_tokens = len(encoding.encode(chc["text"])) # type: ignore
usage = {
"prompt_tokens": pmt_tokens,
"completion_tokens": chc_tokens,
"total_tokens": pmt_tokens + chc_tokens,
}
usages.append(usage)
return usages

@ -0,0 +1,157 @@
"""OpenAIChat client."""
import copy
import logging
import os
from typing import Any, Dict, Optional
from manifest.clients.openai import OpenAIClient
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
# List from https://platform.openai.com/docs/models/model-endpoint-compatibility
OPENAICHAT_ENGINES = {"gpt-3.5-turbo", "gpt-4", "gpt-4-32k"}
class OpenAIChatClient(OpenAIClient):
"""OpenAI Chat client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "gpt-3.5-turbo"),
"temperature": ("temperature", 1.0),
"max_tokens": ("max_tokens", 10),
"n": ("n", 1),
"top_p": ("top_p", 1.0),
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
}
REQUEST_CLS = LMRequest
NAME = "openaichat"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OpenAI server.
connection_str is used as the API key if the OPENAI_API_KEY environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
"variable or pass through `client_connection`."
)
self.host = "https://api.openai.com/v1"
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. "
f"Must be {OPENAICHAT_ENGINES}."
)
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/chat/completions"
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for chat.
Args:
request_params: request params.
Returns:
formatted request params.
"""
# Format for chat model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
messages = [{"role": "user", "content": prompt} for prompt in prompt_list]
request_params["messages"] = messages
return request_params
def _format_request_from_chat(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response for standard response from chat.
Args:
response_dict: response.
Return:
formatted response.
"""
new_choices = []
response_dict = copy.deepcopy(response_dict)
for message in response_dict["choices"]:
new_choices.append({"text": message["message"]["content"]})
response_dict["choices"] = new_choices
return response_dict
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_chat(response_dict)
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
# Reformat for text model
response_dict = self._format_request_from_chat(response_dict)
return response_dict

@ -0,0 +1,204 @@
"""OpenAI client."""
import copy
import logging
import os
from typing import Any, Dict, List, Optional
import numpy as np
import tiktoken
from manifest.clients.openai import OpenAIClient
from manifest.request import EmbeddingRequest
logger = logging.getLogger(__name__)
OPENAI_EMBEDDING_ENGINES = {
"text-embedding-ada-002",
}
class OpenAIEmbeddingClient(OpenAIClient):
"""OpenAI client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "text-embedding-ada-002"),
}
REQUEST_CLS = EmbeddingRequest
NAME = "openaiembedding"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OpenAI server.
connection_str is used as the API key if the OPENAI_API_KEY environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
"variable or pass through `client_connection`."
)
self.host = "https://api.openai.com/v1"
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_EMBEDDING_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. "
f"Must be {OPENAI_EMBEDDING_ENGINES}."
)
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/embeddings"
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "data" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["data"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["data"])
if split_usage:
response["usage"] = split_usage
return response
def _format_request_for_embedding(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for embedding.
Args:
request_params: request params.
Returns:
formatted request params.
"""
# Format for embedding model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
request_params["input"] = prompt_list
return request_params
def _format_request_from_embedding(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response from embedding for standard response.
Args:
response_dict: response.
Return:
formatted response.
"""
new_choices = []
response_dict = copy.deepcopy(response_dict)
for res in response_dict.pop("data"):
new_choices.append({"array": np.array(res["embedding"])})
response_dict["choices"] = new_choices
return response_dict
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_embedding(response_dict)
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
# Reformat for text model
response_dict = self._format_request_from_embedding(response_dict)
return response_dict
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
try:
encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
except Exception:
return []
prompt = request["input"]
if isinstance(prompt, str):
prompts = [prompt]
else:
prompts = prompt
assert len(prompts) == len(choices)
usages = []
for pmt in prompts:
pmt_tokens = len(encoding.encode(pmt))
# No completion tokens for embedding models
chc_tokens = 0
usage = {
"prompt_tokens": pmt_tokens,
"completion_tokens": chc_tokens,
"total_tokens": pmt_tokens + chc_tokens,
}
usages.append(usage)
return usages
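To make the reshaping done by `_format_request_for_embedding` and `_format_request_from_embedding` concrete, here is a minimal standalone sketch of the same transformation; the helper names and payload values are illustrative, not the client's actual API:
```python
from typing import Any, Dict, List, Union

import numpy as np


def to_embedding_request(prompt: Union[str, List[str]]) -> Dict[str, Any]:
    """Mirror the prompt -> input reshaping: the API expects a list of inputs."""
    return {"input": [prompt] if isinstance(prompt, str) else list(prompt)}


def from_embedding_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the data -> choices reshaping: each embedding becomes a numpy array."""
    return {"choices": [{"array": np.array(item["embedding"])} for item in response["data"]]}


# Round trip with a fake API payload.
request = to_embedding_request("hello")
fake_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
print(from_embedding_response(fake_response)["choices"][0]["array"].shape)  # (3,)
```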

@ -31,9 +31,9 @@ class TOMAClient(Client):
"top_p": ("top_p", 0.9),
"top_k": ("top_k", 40),
"stop_sequences": ("stop", []),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
NAME = "toma"
def connect(
self,
@ -121,7 +121,7 @@ class TOMAClient(Client):
Returns:
model params.
"""
return {"model_name": "toma", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def get_model_heartbeats(self) -> Dict[str, Dict]:
"""
@ -143,12 +143,13 @@ class TOMAClient(Client):
}
return heartbeats
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@ -30,9 +30,9 @@ class TOMADiffuserClient(TOMAClient):
"width": ("width", 512),
"n": ("n", 1),
"guidance_scale": ("guidance_scale", 7.5),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest # type: ignore
NAME = "tomadiffuser"
def get_model_params(self) -> Dict:
"""
@ -44,14 +44,15 @@ class TOMADiffuserClient(TOMAClient):
Returns:
model params.
"""
return {"model_name": "tomadiffuser", "engine": getattr(self, "engine")}
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
request: request
Return:
response as dict

@ -1,4 +1,5 @@
"""Manifest class."""
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@ -12,41 +13,37 @@ from manifest.clients.ai21 import AI21Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.toma import TOMAClient
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,
"cohere": CohereClient,
"ai21": AI21Client,
"huggingface": HuggingFaceClient,
"dummy": DummyClient,
"toma": TOMAClient,
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
HuggingFaceClient.NAME: HuggingFaceClient,
HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
DummyClient.NAME: DummyClient,
TOMAClient.NAME: TOMAClient,
}
# ChatGPT
try:
from manifest.clients.chatgpt import ChatGPTClient
CLIENT_CONSTRUCTORS["chatgpt"] = ChatGPTClient
except Exception:
logger.info("ChatGPT not installed. Skipping import.")
pass
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
CLIENT_CONSTRUCTORS["diffuser"] = DiffuserClient
CLIENT_CONSTRUCTORS["tomadiffuser"] = TOMADiffuserClient
CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
logger.info("Diffusion not supported. Skipping import.")
pass
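Keying the registry on each client's `NAME` attribute means a new client only needs to define that attribute and be added to the dict. A hedged sketch with a hypothetical subclass (`MyEchoClient` is illustrative, not part of Manifest):
```python
from manifest.clients.dummy import DummyClient


class MyEchoClient(DummyClient):
    """Toy client that reuses the dummy client's behavior under a new name."""

    NAME = "myecho"


# Register it exactly like the built-in clients in the CLIENT_CONSTRUCTORS dict above.
CLIENT_CONSTRUCTORS[MyEchoClient.NAME] = MyEchoClient
```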
@ -69,7 +66,6 @@ class Manifest:
client_connection: Optional[str] = None,
cache_name: str = "noop",
cache_connection: Optional[str] = None,
session_id: Optional[str] = None,
stop_token: str = "",
**kwargs: Any,
):
@ -81,9 +77,6 @@ class Manifest:
client_connection: connection string for client.
cache_name: name of cache.
cache_connection: connection string for cache.
session_id: session id for user session cache.
None (default) means no session logging.
"_default" means generate new session id.
stop_token: stop token prompt generation.
Can be overridden in run
@ -114,17 +107,6 @@ class Manifest:
self.client = CLIENT_CONSTRUCTORS[self.client_name]( # type: ignore
client_connection, client_args=kwargs
)
if session_id is not None:
if self.client_name == "diffuser":
raise NotImplementedError(
"Session logging not implemented for Diffuser client."
)
if session_id == "_default":
# Set session_id to None for Session random id
session_id = None
self.session = Session(session_id)
else:
self.session = None
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@ -195,11 +177,144 @@ class Manifest:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
return
def _split_cached_requests(
self,
request: Request,
overwrite_cache: bool,
) -> Tuple[Dict[int, Response], Request]:
"""Split a request into cached responses and a Request with the remaining prompts to run.
Args:
request: request object.
overwrite_cache: whether to overwrite cache.
Returns:
cached_idx_to_response: dict of cached responses.
new_request: request object with only prompts to run.
"""
cached_idx_to_response: Dict[int, Response] = {}
new_request = copy.deepcopy(request)
if not overwrite_cache:
if isinstance(new_request.prompt, list):
new_request.prompt = []
for idx, prompt_str in enumerate(request.prompt):
single_request = copy.deepcopy(request)
single_request.prompt = prompt_str
possible_response = self.cache.get(
self.client.get_cache_key(single_request)
)
if possible_response:
cached_idx_to_response[idx] = possible_response
else:
new_request.prompt.append(prompt_str)
else:
possible_response = self.cache.get(
self.client.get_cache_key(new_request)
)
if possible_response:
cached_idx_to_response[0] = possible_response
new_request.prompt = None
return cached_idx_to_response, new_request
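A toy sketch of the indexing idea (plain dicts, not Manifest's cache API): cache hits keep their original batch index so they can be stitched back in order, while misses form the smaller request that is actually sent to the model:
```python
from typing import Dict, List, Tuple


def split_cached(prompts: List[str], cache: Dict[str, str]) -> Tuple[Dict[int, str], List[str]]:
    """Return cache hits keyed by original index and the prompts still to run."""
    hits: Dict[int, str] = {}
    misses: List[str] = []
    for idx, prompt in enumerate(prompts):
        if prompt in cache:
            hits[idx] = cache[prompt]
        else:
            misses.append(prompt)
    return hits, misses


hits, misses = split_cached(["prompt a", "prompt b"], {"prompt a": "cached answer"})
assert hits == {0: "cached answer"} and misses == ["prompt b"]
```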
def _stitch_responses_and_cache(
self,
request: Request,
response: Union[Response, None],
cached_idx_to_response: Dict[int, Response],
) -> Response:
"""Stich together the cached and uncached responses."""
# We stitch the responses (the choices) here from both the new request the
# cached entries.
all_model_choices = []
all_usages = []
all_input_prompts = []
response_idx = 0
number_prompts = len(cached_idx_to_response)
single_output = False
if response:
if isinstance(response.get_request()["prompt"], str):
single_output = True
number_prompts += 1
else:
number_prompts += len(response.get_request()["prompt"])
response_gen_key = None
response_logits_key = None
response_item_key = None
for idx in range(number_prompts):
if idx in cached_idx_to_response:
cached_res = cached_idx_to_response[idx]
response_gen_key = cached_res.generation_key
response_logits_key = cached_res.logits_key
response_item_key = cached_res.item_key
response_usage_key = cached_res.usage_key
all_input_prompts.append(cached_res.get_request()["prompt"])
json_response = cached_res.get_json_response()
if request.n == 1:
assert (
len(json_response[response_gen_key]) == 1
), "cached response should have only one choice"
all_model_choices.extend(json_response[response_gen_key])
if response_usage_key:
all_usages.extend(json_response[response_usage_key])
else:
assert response is not None, "response should not be None"
response = cast(Response, response)
response_gen_key = response.generation_key
response_logits_key = response.logits_key
response_item_key = response.item_key
response_usage_key = response.usage_key
# the choices list in the response is a flat one.
# length is request.n * num_prompts
current_choices = response.get_json_response()[response_gen_key][
response_idx * request.n : (response_idx + 1) * request.n
]
all_model_choices.extend(current_choices)
if isinstance(response.get_request()["prompt"], list):
prompt = response.get_request()["prompt"][response_idx]
else:
prompt = str(response.get_request()["prompt"])
if response_usage_key:
usage = response.get_json_response()[response_usage_key][
response_idx * request.n : (response_idx + 1) * request.n
]
all_usages.extend(usage)
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt
cache_key = self.client.get_cache_key(new_request)
new_response_key = copy.deepcopy(response.get_json_response())
new_response_key[response_gen_key] = current_choices
if response_usage_key:
new_response_key[response_usage_key] = usage
self.cache.set(cache_key, new_response_key)
response_idx += 1
new_request = copy.deepcopy(request)
new_request.prompt = (
all_input_prompts
if len(all_input_prompts) > 1 or not single_output
else all_input_prompts[0]
)
new_response = {response_gen_key: all_model_choices}
if response_usage_key:
new_response[response_usage_key] = all_usages
response_obj = Response(
new_response,
cached=len(cached_idx_to_response) > 0,
request_params=self.client.get_cache_key(new_request),
generation_key=response_gen_key,
logits_key=response_logits_key,
item_key=response_item_key,
usage_key=response_usage_key,
)
return response_obj
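The uncached branch relies on the response keeping a flat choices list with `n` generations per prompt, so prompt `i` owns the slice `[i * n : (i + 1) * n]`. A small illustration with made-up values:
```python
# Three prompts, n = 2 generations each, flattened in prompt order.
n = 2
choices = ["p0-a", "p0-b", "p1-a", "p1-b", "p2-a", "p2-b"]
per_prompt = [choices[i * n : (i + 1) * n] for i in range(len(choices) // n)]
assert per_prompt == [["p0-a", "p0-b"], ["p1-a", "p1-b"], ["p2-a", "p2-b"]]
```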
def run(
self,
prompt: Union[str, List[str]],
overwrite_cache: bool = False,
run_id: Optional[str] = None,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
@ -210,7 +325,6 @@ class Manifest:
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
run_id: run id for cache to repeat same run.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
@ -223,29 +337,88 @@ class Manifest:
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as a dict so the client's "pop" methods remove used arguments
request_params = self.client.get_request_params(prompt, kwargs)
request_params = self.client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if is_batch and request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
possible_request, full_kwargs = self.client.get_request(request_params)
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
response = self.client.run_request(request_params)
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return final_response.get_response(stop_token, is_batch)
async def arun_batch(
self,
prompts: List[str],
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
"""
Run a batch of prompts asynchronously.
Args:
prompts: prompts to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as a dict so the client's "pop" methods remove used arguments
request_params = self.client.get_request(prompts, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent
cache_key.update(self.client.get_model_params())
if run_id:
cache_key["run_id"] = run_id
response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
# Log session dictionary values
if self.session:
self.session.log_query(cache_key, response_obj.to_dict())
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
response = await self.client.arun_batch_request(request_params)
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return response_obj
return final_response
else:
return response_obj.get_response(stop_token, is_batch)
return cast(
Union[List[str], List[np.ndarray]],
final_response.get_response(stop_token, True),
)
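Since batch runs enforce `n == 1`, asking for multiple generations per prompt fails before any model call. A hedged sketch, assuming the package exports `Manifest` at the top level and using the dummy client with the no-op cache for illustration:
```python
import asyncio

from manifest import Manifest

manifest = Manifest(client_name="dummy", cache_name="noop")
try:
    # n > 1 is rejected for batch requests before the client is called.
    asyncio.run(manifest.arun_batch(["prompt a", "prompt b"], n=2))
except ValueError as err:
    print(err)  # Batch mode does not support n > 1.
```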
def score_prompt(
self,
@ -258,8 +431,6 @@ class Manifest:
Returns the response object with logits of the prompt.
Prompt scoring is not part of a session cache.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
@ -268,66 +439,31 @@ class Manifest:
response from prompt.
"""
# Must pass kwargs as a dict so the client's "pop" methods remove used arguments
request_params = self.client.get_request_params(prompt, kwargs)
request_params = self.client.get_request(prompt, kwargs)
request_params.request_type = "score_prompt"
if request_params.n > 1:
raise ValueError("Sequence scoring does not support n > 1.")
try:
possible_request, full_kwargs = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
self._validate_kwargs(kwargs, request_params)
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent
cache_key.update(self.client.get_model_params())
response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
return response_obj.to_dict()
def get_last_queries(
self,
last_n: int = -1,
return_raw_values: bool = False,
stop_token: Optional[str] = None,
) -> List[Tuple[Any, Any]]:
"""
Get last n queries from current session.
If last_n is -1, return all queries. By default, only the prompt text and
result text are returned unless return_raw_values is True.
Args:
last_n: last n queries.
return_raw_values: whether to return raw values as dicts.
stop_token: stop token for prompt results to be applied to all results.
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
try:
response = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
else:
# Nothing to run
response = None
Returns:
last n list of queries and outputs.
"""
if self.session is None:
raise ValueError(
"Session was not initialized. Set `session_id` when loading Manifest."
)
stop_token = stop_token if stop_token is not None else self.stop_token
last_queries = self.session.get_last_queries(last_n)
if not return_raw_values:
last_queries = [
(
query["prompt"],
Response.from_dict(response).get_response(
stop_token, is_batch=isinstance(query["prompt"], list)
),
) # type: ignore
for query, response in last_queries
]
return last_queries
def open_explorer(self) -> None:
"""Open the explorer for jupyter widget."""
# Open explorer
# TODO: implement
pass
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
return final_response.to_dict()

@ -3,6 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
DEFAULT_REQUEST_KEYS = {
"client_timeout": ("client_timeout", 60), # seconds
"batch_size": ("batch_size", 1),
"run_id": ("run_id", None),
"request_type": ("request_type", None),
}
class Request(BaseModel):
"""Request object."""
@ -17,7 +25,16 @@ class Request(BaseModel):
n: int = 1
# Timeout
client_timeout: int = 60
client_timeout: int = 120
# Run id used to repeat run with same parameters
run_id: Optional[str] = None
# Batch size for async batch run
batch_size: int = 8
# Request type. None means a standard completion request; used for prompt scoring.
request_type: Optional[str] = None
def to_dict(
self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
@ -79,9 +96,18 @@ class LMRequest(Request):
frequency_penalty: float = 0
class EmbeddingRequest(Request):
"""Embedding Request object."""
pass
class DiffusionRequest(Request):
"""Diffusion Model Request object."""
# Request type
request_type: str = "diffusion"
# Number of steps
num_inference_steps: int = 50
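A short sketch of how the new fields sit on a request object; values are illustrative and rely on the pydantic defaults shown above:
```python
from manifest.request import DiffusionRequest, LMRequest

lm_req = LMRequest(prompt="a test prompt", run_id="demo-run", batch_size=8)
diff_req = DiffusionRequest(prompt="a test prompt", num_inference_steps=50)
assert lm_req.client_timeout == 120  # new default timeout
assert diff_req.request_type == "diffusion"
```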

@ -4,6 +4,25 @@ from typing import Any, Dict, List, Union
import numpy as np
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"openaiembedding": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"huggingfaceembedding": {
"logits_key": "token_logprobs",
"item_key": "array",
},
}
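For orientation, a hedged sketch (made-up numbers) of the response shape these constructors are wired for: choices carry an `array` item rather than `text`, and usage, when present, is a list with one entry per choice:
```python
import numpy as np

embedding_like_response = {
    "choices": [{"array": np.random.rand(768)}],
    "usage": [{"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5}],
}
```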
class NumpyArrayEncoder(json.JSONEncoder):
"""Numpy array encoder."""
@ -20,12 +39,13 @@ class Response:
def __init__(
self,
response: Dict,
response: Dict, # TODO: make pydantic model
cached: bool,
request_params: Dict,
request_params: Dict, # TODO: use request pydantic model
generation_key: str = "choices",
logits_key: str = "token_logprobs",
item_key: str = "text",
usage_key: str = "usage",
):
"""
Initialize response.
@ -41,6 +61,7 @@ class Response:
self.generation_key = generation_key
self.logits_key = logits_key
self.item_key = item_key
self.usage_key = usage_key
self.item_dtype = None
if isinstance(response, dict):
self._response = response
@ -55,6 +76,16 @@ class Response:
"Response must be serialized to a dict with a nonempty"
f" list of choices. Response is\n{self._response}."
)
# Turn off usage if it is not in response
if self.usage_key not in self._response:
self.usage_key = None
else:
if not isinstance(self._response[self.usage_key], list):
raise ValueError(
"Response must be a list with usage dicts, one per choice."
f" Response is\n{self._response}."
)
if self.item_key not in self._response[self.generation_key][0]:
raise ValueError(
"Response must be serialized to a dict with a "

@ -1,156 +0,0 @@
"""User query session logging."""
import logging
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from manifest.caches.serializers import Serializer
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
class Session:
"""A user session for caching requests."""
def __init__(self, session_id: Optional[str] = None) -> None:
"""
Initialize session.
If session_id already exists, will append to existing session.
Args:
session_id: session id.
"""
manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
self.db_file = manifest_home / ".manifest" / "session.db"
self.db_file.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(str(self.db_file))
self.serializer = Serializer()
self._create_table()
if not session_id:
self.session_id = str(uuid.uuid4())
self.query_id = 0
else:
self.session_id = session_id
self.query_id = self._get_latest_query_id(self.session_id)
self.query_id += 1
logger.info(f"Starting session {self.session_id}")
return
def close(self) -> None:
"""Close the client."""
self.conn.close()
@classmethod
def get_session_keys(cls, db_file: Path) -> List[str]:
"""Get available session keys from cached file."""
try:
conn = sqlite3.connect(str(db_file))
query = """SELECT DISTINCT session_id FROM queries"""
cur = conn.cursor()
res = cur.execute(query)
return [x[0] for x in res.fetchall()]
except sqlite3.OperationalError:
logger.info(
"There is no database with the 'queries' table. "
"Are you sure you are using the right session file"
)
return []
def _execute_query(self, query: str, *args: Any) -> Any:
"""
Execute query with optional args.
Args:
query: query to execute.
"""
cur = self.conn.cursor()
res = cur.execute(query, args)
self.conn.commit()
return res
def _create_table(self) -> None:
"""Create table if not exists."""
query = """CREATE TABLE IF NOT EXISTS queries (
query_id integer NOT NULL,
session_id text NOT NULL,
query_key text NOT NULL,
response_key text NOT NULL
);"""
self._execute_query(query)
return
def _get_latest_query_id(self, session_id: str) -> int:
"""
Get latest query id issued if resuming session.
If no session_id, return -1.
Args:
session_id: session id.
Returns:
latest query id.
"""
query = """SELECT query_id
FROM queries
WHERE session_id = ?
ORDER BY query_id DESC LIMIT 1"""
res = self._execute_query(query, session_id).fetchone()
if res:
return res[0]
return -1
def log_query(
self, query_key: Dict[str, Any], response_key: Dict[str, Any]
) -> None:
"""
Log the query and response.
Args:
query_key: query of user (dump of request params).
response_key: response of server (dump of response).
"""
query = """INSERT INTO queries VALUES (?, ?, ?, ?);"""
self._execute_query(
query,
self.query_id,
self.session_id,
self.serializer.request_to_key(query_key),
self.serializer.response_to_key(response_key),
)
self.query_id += 1
return
def get_last_queries(
self, last_n: int = -1
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
"""
Get last n queries from current session.
If last_n is -1, return all queries.
Args:
last_n: last n queries.
Returns:
last n list of queries and outputs.
"""
first_query = self.query_id - last_n if last_n > 0 else -1
query = """SELECT query_key, response_key
FROM queries
WHERE session_id = ? AND query_id >= ?
ORDER BY query_id;"""
res = self._execute_query(query, self.session_id, first_query)
parsed_res = [
(
self.serializer.key_to_request(pair[0]),
self.serializer.key_to_response(pair[1]),
)
for pair in res.fetchall()
]
return parsed_res

@ -1 +1 @@
__version__ = "0.1.2"
__version__ = "0.1.3"

@ -9,6 +9,7 @@ module = [
"deepspeed",
"numpy",
"diffusers",
"sentence_transformers",
"sqlitedict",
"dill",
"accelerate",

@ -21,7 +21,7 @@ with open(ver_path) as ver_file:
NAME = "manifest-ml"
DESCRIPTION = "Manifest for Prompting Foundation Models."
URL = "https://github.com/HazyResearch/manifest"
EMAIL = "lorr1@cs.stanford.edu"
EMAIL = "laurel.orr@numbersstation.ai"
AUTHOR = "Laurel Orr"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = main_ns["__version__"]
@ -32,27 +32,28 @@ REQUIRED = [
"pydantic>=1.9.0",
"redis>=4.3.1",
"requests>=2.27.1",
"aiohttp>=3.8.0",
"sqlitedict>=2.0.0",
"tenacity>=8.2.0",
"tiktoken>=0.3.0",
"xxhash>=3.0.0",
]
# What packages are optional?
EXTRAS = {
"api": [
"accelerate>=0.10.0",
"deepspeed>=0.7.0",
"diffusers>=0.6.0",
"Flask>=2.1.2",
"accelerate>=0.10.0",
"transformers>=4.20.0,<4.26.0",
"sentence_transformers>=2.2.0",
"torch>=1.8.0",
"transformers>=4.20.0,<4.26.0",
],
"app": [
"fastapi>=0.70.0",
"uvicorn>=0.18.0",
],
"chatgpt": [
"pyChatGPT>=0.4.3",
],
"diffusers": [
"pillow>=9.0.0",
],
@ -60,7 +61,7 @@ EXTRAS = {
"pg8000",
"cloud-sql-python-connector[pg8000]>=1.0.0",
"sqlalchemy",
],
],
"dev": [
"autopep8>=1.6.0",
"black>=22.3.0",

@ -42,10 +42,3 @@ def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None
engine = sqlalchemy.create_engine(url)
monkeypatch.setattr(sqlalchemy, "create_engine", lambda *args, **kwargs: engine)
return engine # type: ignore
@pytest.fixture
def session_cache(tmpdir: str) -> Generator[Path, None, None]:
"""Session cache dir."""
os.environ["MANIFEST_HOME"] = str(tmpdir)
yield Path(tmpdir)

@ -1,5 +1,5 @@
"""Cache test."""
from typing import cast
from typing import Dict, cast
import numpy as np
import pytest
@ -13,12 +13,15 @@ from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
def _get_postgres_cache(**kwargs) -> Cache: # type: ignore
def _get_postgres_cache(
client_name: str = "", cache_args: Dict = {}
) -> Cache: # type: ignore
"""Get postgres cache."""
cache_args.update({"cache_user": "", "cache_password": "", "cache_db": ""})
return PostgresCache(
"postgres",
cache_args={"cache_user": "", "cache_password": "", "cache_db": ""},
**kwargs,
client_name=client_name,
cache_args=cache_args,
)
@ -85,28 +88,103 @@ def test_get(
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
test_response = {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
response = cache.get(test_request)
assert response is None
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
# Test array
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}]}
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
elif cache_type == "redis":
cache = RedisCache(redis_cache, client_name="diffuser")
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=True, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
# Test array byte string
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}]}
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert response.is_cached()
assert response.get_request() == test_request
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_get_batch_prompt(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
cache = cast(Cache, SQLiteCache(sqlite_cache))
elif cache_type == "redis":
cache = cast(Cache, RedisCache(redis_cache))
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": ["hello", "goodbye"], "testA": "world"}
test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]}
response = cache.get(test_request)
assert response is None
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == ["hello", "goodbye"]
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute_arr = lambda: {"choices": [{"array": arr}]}
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
# Test array
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
elif cache_type == "redis":
@ -114,9 +192,47 @@ def test_get(
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
response = cache.get(test_request, overwrite_cache=False, compute=compute_arr)
assert np.allclose(response.get_response(), arr)
assert not response.is_cached()
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays byte serializer
arr = np.random.rand(4, 4)
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
@ -137,14 +253,11 @@ def test_noop_cache() -> None:
# Assert always not cached
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
test_response = {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request)
assert response is None
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response is None

@ -1,7 +1,7 @@
"""
Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
We just test the dummy client.
"""
from manifest.clients.dummy import DummyClient
@ -27,17 +27,38 @@ def test_get_request() -> None:
"""Test client get request."""
args = {"n": 3}
client = DummyClient(connection_str=None, client_args=args)
request_params = client.get_request_params("hello", {})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}
request_params = client.get_request_params("hello", {"n": 5})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 5}
assert request_func() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request_params(["hello"] * 5, {"n": 1})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": ["hello"] * 5, "num_results": 1}
assert request_func() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request("hello", {})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 3,
"engine": "dummy",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 3,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3,
}
request_params = client.get_request("hello", {"n": 5})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 5,
"engine": "dummy",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}
request_params = client.get_request(["hello"] * 5, {"n": 1})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": ["hello"] * 5,
"num_results": 1,
"engine": "dummy",
}
assert response.get_json_response() == {
"choices": [{"text": "hello"}] * 5,
"usage": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5,
}

@ -4,9 +4,11 @@ import math
import os
from subprocess import PIPE, Popen
import numpy as np
import pytest
from manifest.api.models.huggingface import MODEL_REGISTRY, TextGenerationModel
from manifest.api.models.sentence_transformer import SentenceTransformerModel
NOCUDA = 0
try:
@ -141,12 +143,43 @@ def test_gpt_score() -> None:
result = model.score_sequence(inputs)
assert result is not None
assert len(result) == 2
assert math.isclose(round(result[0][0], 3), -19.935)
assert math.isclose(round(result[1][0], 3), -45.831)
assert math.isclose(round(result[0][0], 3), -46.71)
assert math.isclose(round(result[1][0], 3), -12.752)
assert isinstance(result[0][1], list)
assert isinstance(result[1][1], list)
def test_embed() -> None:
"""Test embedding pipeline."""
model = TextGenerationModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = ["Why is the sky green?", "Cats are butterflies"]
embeddings = model.embed(inputs)
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (2, 768)
model2 = SentenceTransformerModel(
model_name_or_path="all-mpnet-base-v2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = ["Why is the sky green?", "Cats are butterflies"]
embeddings = model2.embed(inputs)
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (2, 768)
def test_batch_gpt_generate() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(

File diff suppressed because it is too large

@ -10,10 +10,10 @@ def test_llm_init() -> None:
request = LMRequest(temperature=0.5)
assert request.temperature == 0.5
request = LMRequest(**{"temperature": 0.5})
request = LMRequest(**{"temperature": 0.5}) # type: ignore
assert request.temperature == 0.5
request = LMRequest(**{"temperature": 0.5, "prompt": "test"})
request = LMRequest(**{"temperature": 0.5, "prompt": "test"}) # type: ignore
assert request.temperature == 0.5
assert request.prompt == "test"
@ -26,10 +26,10 @@ def test_diff_init() -> None:
request = DiffusionRequest(height=128)
assert request.height == 128
request = DiffusionRequest(**{"height": 128})
request = DiffusionRequest(**{"height": 128}) # type: ignore
assert request.height == 128
request = DiffusionRequest(**{"height": 128, "prompt": "test"})
request = DiffusionRequest(**{"height": 128, "prompt": "test"}) # type: ignore
assert request.height == 128
assert request.prompt == "test"

@ -1,13 +1,12 @@
"""Cache test."""
import json
from pathlib import Path
import numpy as np
from manifest.caches.serializers import ArraySerializer
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer
def test_response_to_key(session_cache: Path) -> None:
def test_response_to_key_array() -> None:
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)
@ -18,3 +17,16 @@ def test_response_to_key(session_cache: Path) -> None:
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])
def test_response_to_key_numpybytes() -> None:
"""Test array serializer initialization."""
serializer = NumpyByteSerializer()
arr = np.random.rand(4, 4)
res = {"choices": [{"array": arr}]}
key = serializer.response_to_key(res)
key_dct = json.loads(key)
assert isinstance(key_dct["choices"][0]["array"], str)
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])

@ -1,81 +0,0 @@
"""Test session."""
import sqlite3
from pathlib import Path
import pytest
from manifest.session import Session
@pytest.mark.usefixtures("session_cache")
def test_init(session_cache: Path) -> None:
"""Test session initialization."""
session = Session()
assert isinstance(session.conn, sqlite3.Connection)
assert session.db_file == session_cache / ".manifest" / "session.db"
assert session.query_id == 0
assert (session_cache / ".manifest" / "session.db").exists()
# Remove session cache file.
(session_cache / ".manifest" / "session.db").unlink()
session = Session("dog_days")
assert isinstance(session.conn, sqlite3.Connection)
assert session.db_file == session_cache / ".manifest" / "session.db"
assert session.query_id == 0
assert session.session_id == "dog_days"
assert (session_cache / ".manifest" / "session.db").exists()
session.close()
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session()
assert session.get_last_queries(1) == []
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
assert session.query_id == 1
assert session.get_last_queries(1) == [(query_key, response_key)]
query_key2 = {"query2": "What is your name?", "time": "now"}
response_key2 = {"response2": "I don't have a name", "engine": "nodel"}
session.log_query(query_key2, response_key2)
assert session.query_id == 2
assert len(session.get_last_queries(1)) == 1
assert session.get_last_queries(2) == [
(query_key, response_key),
(query_key2, response_key2),
]
session.close()
@pytest.mark.usefixtures("session_cache")
def test_resume_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session(session_id="dog_days")
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
session.close()
session = Session(session_id="dog_days")
assert session.query_id == 1
@pytest.mark.usefixtures("session_cache")
def test_session_keys(session_cache: Path) -> None:
"""Test get session keys."""
# Assert empty before queries
assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == []
# Add queries and make sure session is logged
session = Session(session_id="dog_days")
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
session.close()
assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == [
"dog_days"
]