feat: async support, openai chatgpt, batch cache fix (#68)

pull/82/head
Laurel Orr 1 year ago committed by GitHub
parent bed6773f75
commit 395ac06a95

@ -1,6 +1,18 @@
0.1.1 - Unreleased
---------------------
Added
^^^^^
* Async support in arun_batch
Fixed
^^^^^
* Batched runs now cache individual items
* Score prompt no longer truncates outside tokens
Removed
^^^^^^^
* Deprecated ChatGPT in favor of openaichat, which uses the OpenAI chat completions API
* Deprecated Sessions
0.1.0 - 2022-01-31
---------------------

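For readers migrating off the deprecated ChatGPT client, a minimal usage sketch of the replacement (assuming `OPENAI_API_KEY` is set in the environment; the prompt text is illustrative):

```python
from manifest import Manifest

# openaichat targets the OpenAI chat completions endpoint
# (gpt-3.5-turbo by default) and replaces the deprecated chatgpt client.
manifest = Manifest(client_name="openaichat")
print(manifest.run("Describe the best sandwich in the world in one short sentence."))
```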
@ -22,12 +22,6 @@ Install with diffusion support:
pip install manifest-ml[diffusers]
```
Install with ChatGPT support:
```bash
pip install manifest-ml[chatgpt]
```
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
Install with HuggingFace local model support:
```bash
pip install manifest-ml[api]
@ -106,24 +100,6 @@ manifest = Manifest(
```
As a hint, if you want to get Redis running, see the `docker run` command in the Development section below.
## Sessions
Each Manifest run supports a session that, in addition to a global cache, connects to a local SQLite DB to store user query history.
```python
manifest = Manifest(
client_name = "openai",
cache_name = "sqlite",
cache_connection = "mycache.sqlite",
session_id = "grass_color",
)
```
will start a Manifest session with the session name `grass_color`. This can be helpful for a user to logically keep track of sessions, see interaction history, and resume sessions if desired. If the session id provided is `_default`, we generate a random id for the user.
After a few queries, the user can explore their history
```python
manifest.get_last_queries(4)
```
will retrieve the last 4 model queries and responses.
## Running Queries
Once you have a Manifest connection, you can write and develop prompts.

@ -0,0 +1,27 @@
import asyncio
import time
from manifest import Manifest
def main():
manifest = Manifest(
client_name="openaichat",
)
print("Running in serial")
prompts = [f"Tell me something interesting about {i}" for i in range(50)]
st = time.time()
for pmt in prompts:
_ = manifest.run(pmt)
print(f"For loop: {time.time() - st :.2f}")
print("Running with async")
st = time.time()
_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30))
print(f"Async loop: {time.time() - st :.2f}")
if __name__ == "__main__":
main()

@ -1,63 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"import os\n",
"\n",
"# ChatGPT tries hard not to give people programmatic access.\n",
"# As a warning, this will open a browser window.\n",
"# You need to install xvfb and chromium for linux headless mode to work\n",
"# See https://github.com/terry3041/pyChatGPT\n",
"\n",
"# The responses are not fast\n",
"manifest = Manifest(\n",
" client_name=\"chatgpt\",\n",
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
")\n",
"print(manifest.run(\"Describe in a single, short sentence what is the best sandwhich in the world. Be short and concise.\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mlcore",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "1ea9cc00d433352044b557b1784ac6e58df03de4b7bb312554014351989eb135"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -2,6 +2,5 @@
from manifest.manifest import Manifest
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session
__all__ = ["Manifest", "Response", "Session"]
__all__ = ["Manifest", "Response", "Request"]

@ -174,13 +174,13 @@ def completions() -> Response:
if model_type == "diffuser":
# Assign None logprob and tokens as they're not supported in diffusers
results = [
{"array": r[0], "logprob": None, "token_logprobs": None}
{"array": r[0], "logprob": None, "tokens": None, "token_logprobs": None}
for r in result_gens
]
res_type = "image_generation"
else:
results = [
{"text": r[0], "logprob": r[1], "token_logprobs": r[2]}
{"text": r[0], "logprob": r[1], "tokens": r[2], "token_logprobs": r[3]}
for r in result_gens
]
res_type = "text_completion"
@ -241,7 +241,8 @@ def score_sequence() -> Response:
{
"text": prompt if isinstance(prompt, str) else prompt[i],
"logprob": r[0],
"token_logprobs": r[1],
"tokens": r[1],
"token_logprobs": r[2],
}
for i, r in enumerate(score_list)
]

@ -74,7 +74,7 @@ class DiffuserModel(Model):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -91,12 +91,12 @@ class DiffuserModel(Model):
prompt = [prompt]
result = self.pipeline(prompt, output_type="np.array", **kwargs)
# Return None for logprob, tokens, and token logprobs
return [(im, None, None) for im in result["images"]]
return [(im, None, None, None) for im in result["images"]]
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.

@ -179,6 +179,7 @@ class GenerationPipeline:
"logprobs": logits[
range(num_generated_tokens), i, output_seq[-num_generated_tokens:]
].tolist(),
"tokens": output_seq[-num_generated_tokens:].tolist(),
}
for i, output_seq in enumerate(output_dict.sequences)
]
@ -547,7 +548,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -576,6 +577,7 @@ class TextGenerationModel(HuggingFaceModel):
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[int], r["tokens"]),
cast(List[float], r["logprobs"]),
)
for r in result
@ -585,7 +587,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.
@ -610,21 +612,20 @@ class TextGenerationModel(HuggingFaceModel):
**encoded_prompt,
).logits
# For causal decoders, shift logits and labels
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)[
..., 1:, :
]
masked_log_probs = (
labels_attention_mask.float()
* torch.log_softmax(logits.float(), dim=-1)[..., :-1, :]
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
logits.float(), dim=-1
)
seq_token_log_probs = torch.gather(
masked_log_probs, -1, encoded_prompt["labels"][..., 1:].unsqueeze(-1)
masked_log_probs, -1, encoded_prompt["labels"].unsqueeze(-1)
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_log_prob = seq_token_log_probs.sum(dim=-1)
return [
(seq, seq_token)
for seq, seq_token in zip(
seq_log_prob.tolist(), seq_token_log_probs.tolist()
(seq, tokens, seq_token)
for seq, tokens, seq_token in zip(
seq_log_prob.tolist(),
encoded_prompt["input_ids"].tolist(),
seq_token_log_probs.tolist(),
)
]

@ -48,7 +48,7 @@ class Model(ABC):
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[float]]]:
) -> List[Tuple[Any, float, List[int], List[float]]]:
"""
Generate the prompt from model.
@ -59,7 +59,7 @@ class Model(ABC):
Returns:
list of generated text (list of length 1 for 1 generation).
Each item is the response, answer logprob,
Each item is the response, answer logprob, list of tokens,
and list of logprobs for each token.
"""
raise NotImplementedError()
@ -78,7 +78,7 @@ class Model(ABC):
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[float, List[float]]]:
) -> List[Tuple[float, List[int], List[float]]]:
"""
Score a sequence of choices.
@ -89,6 +89,6 @@ class Model(ABC):
Additional keyword arguments passed along to the :obj:`__call__` method.
Returns:
Tuple of scores for each choice and logprobs for the tokens of each choice.
Tuple of total score, tokens, and logprobs per token.
"""
raise NotImplementedError()
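As a shape sketch of the widened `generate` and `score_sequence` return types above (values are illustrative; the element order follows the docstrings):

```python
from typing import Any, List, Tuple

# generate(): one tuple per generation.
gen_item: Tuple[Any, float, List[int], List[float]] = (
    "generated text",    # response
    -4.2,                # answer logprob (sum of the per-token logprobs)
    [318, 262, 1266],    # tokens
    [-1.1, -1.6, -1.5],  # logprob for each token
)

# score_sequence(): one tuple per scored choice.
score_item: Tuple[float, List[int], List[float]] = (
    -4.2,                # total sequence logprob
    [318, 262, 1266],    # tokens of the choice
    [-1.1, -1.6, -1.5],  # logprob for each token
)
```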

@ -36,6 +36,7 @@ class ModelResponse:
{
key: result[key],
"logprob": result["logprob"],
"tokens": result["tokens"],
"token_logprobs": result["token_logprobs"],
}
if key == "text"

@ -1,22 +1,9 @@
"""Cache for queries and responses."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
from typing import Any, Dict, Union
from manifest.caches.serializers import ArraySerializer, Serializer
from manifest.response import Response
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"generation_key": "choices",
"logits_key": "token_logprobs",
"item_key": "array",
},
}
from manifest.response import RESPONSE_CONSTRUCTORS, Response
CACHE_CONSTRUCTOR = {
"diffuser": ArraySerializer,
@ -101,20 +88,35 @@ class Cache(ABC):
"""Commit any results."""
raise NotImplementedError()
def get(
self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict]
) -> Response:
"""Get the result of request (by calling compute as needed)."""
def get(self, request: Dict) -> Union[Response, None]:
"""Get the cached response for a request, if it exists.
Args:
request: request to get.
Returns:
Response object or None if not in cache.
"""
key = self.serializer.request_to_key(request)
cached_response = self.get_key(key)
if cached_response and not overwrite_cache:
if cached_response:
cached = True
response = self.serializer.key_to_response(cached_response)
else:
# Type Response
response = compute()
self.set_key(key, self.serializer.response_to_key(response))
cached = False
return Response(
response, cached, request, **RESPONSE_CONSTRUCTORS.get(self.client_name, {})
)
return Response(
response,
cached,
request,
**RESPONSE_CONSTRUCTORS.get(self.client_name, {})
)
return None
def set(self, request: Dict, response: Dict) -> None:
"""Set the value for the key.
Args:
request: request to set.
response: response to set.
"""
key = self.serializer.request_to_key(request)
self.set_key(key, self.serializer.response_to_key(response))
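A minimal sketch of the reworked cache interface, mirroring the updated tests later in this diff (the SQLite path is illustrative):

```python
from manifest.caches.sqlite import SQLiteCache

cache = SQLiteCache("mycache.sqlite")  # illustrative path
request = {"prompt": "hello", "engine": "dummy"}

# get() no longer takes a compute callable; a cache miss is just None.
assert cache.get(request) is None

# The caller computes the response and writes it back explicitly.
cache.set(request, {"choices": [{"text": "hi there"}]})

response = cache.get(request)
assert response is not None and response.is_cached()
print(response.get_response())  # "hi there"
```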

@ -100,7 +100,7 @@ class PostgresCache(Cache):
table: table to get key in.
"""
request = (
self.session.query(Request)
self.session.query(Request) # type: ignore
.filter_by(key=self._hash_key(key, table))
.first()
)
@ -119,7 +119,7 @@ class PostgresCache(Cache):
table: table to set key in.
"""
key = self._hash_key(key, table)
request = self.session.query(Request).filter_by(key=key).first()
request = self.session.query(Request).filter_by(key=key).first() # type: ignore
if request:
request.response = value # type: ignore
else:

@ -27,9 +27,9 @@ class AI21Client(Client):
"n": ("numResults", 1),
"top_p": ("topP", 1.0),
"stop_sequences": ("stopSequences", []),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
NAME = "ai21"
def connect(
self,

@ -1,130 +0,0 @@
"""Client class."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from pyChatGPT import ChatGPT
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
class ChatGPTClient(Client):
"""ChatGPT Client class."""
# No params for ChatGPT
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = LMRequest
def connect(
self, connection_str: Optional[str], client_args: Dict[str, Any]
) -> None:
"""
Connect to ChatGPT.
We use https://github.com/terry3041/pyChatGPT.
Arsg:
connection_str: connection string.
client_args: client arguments.
"""
self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
if self.session_key is None:
raise ValueError(
"ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
"variable or pass through `client_connection`. "
"For details, see https://github.com/terry3041/pyChatGPT "
"and go through instructions for getting a session key."
)
self.host = None
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
self._chat_session = ChatGPT(self.session_key, verbose=False)
def close(self) -> None:
"""Close the client."""
self._chat_session = None
def clear_conversations(self) -> None:
"""Clear conversations.
Only works for ChatGPT.
"""
self._chat_session.clear_conversations()
def get_generation_url(self) -> str:
"""Get generation URL."""
return ""
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "chatgpt", "engine": "chatgpt"}
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"model": "chatgpt",
"choices": [
{
"text": response["message"],
}
],
}
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
request: request.
Returns:
request function that takes no input.
request parameters as dict.
"""
if isinstance(request.prompt, list):
raise ValueError("ChatGPT does not support batch inference.")
prompt = str(request.prompt)
request_params = request.to_dict(self.PARAMS)
def _run_completion() -> Dict:
try:
res = self._chat_session.send_message(prompt)
except Exception as e:
logger.error(f"ChatGPT error {e}.")
raise e
return self.format_response(res)
return _run_completion, request_params

@ -1,11 +1,16 @@
"""Client class."""
import asyncio
import copy
import logging
import math
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import requests
from manifest.request import Request
from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
logger = logging.getLogger(__name__)
@ -16,6 +21,7 @@ class Client(ABC):
# Must be overridden by child class
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = Request
NAME: str = None
def __init__(
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
@ -93,7 +99,7 @@ class Client(ABC):
"""
return list(self.PARAMS.keys())
def get_request_params(
def get_request(
self, prompt: Union[str, List[str]], request_args: Dict[str, Any]
) -> Request:
"""
@ -109,7 +115,32 @@ class Client(ABC):
params = {"prompt": prompt}
for key in self.PARAMS:
params[key] = request_args.pop(key, getattr(self, key))
return self.REQUEST_CLS(**params)
for key in DEFAULT_REQUEST_KEYS:
if key not in params and key in request_args:
params[key] = request_args.pop(key)
return self.REQUEST_CLS(**params) # type: ignore
def get_request_params(self, request: Request) -> Dict[str, Any]:
"""Get request params.
Add default keys that we need for requests such as batch_size.
We drop these before sending to the model.
"""
params_to_add = DEFAULT_REQUEST_KEYS.copy()
params_to_add.update(self.PARAMS)
request_params = request.to_dict(params_to_add)
return request_params
def get_cache_key(self, request: Request) -> Dict[str, Any]:
"""Get cache key for request.
Skip keys that are not cache keys such as batch_size.
"""
request_params = self.get_request_params(request)
for key in NOT_CACHE_KEYS:
request_params.pop(key, None)
request_params.update(self.get_model_params())
return request_params
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
@ -125,7 +156,90 @@ class Client(ABC):
raise ValueError(f"Invalid response: {response}")
return response
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
def split_requests(
self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt"
) -> List[Dict[str, Any]]:
"""Split the request into batch-sized requests.
Args:
request_params: request params.
batch_size: batch size for requests.
key: key to batch over
Returns:
list of request params.
"""
data = copy.deepcopy(request_params[key])
data_size = len(request_params[key])
request_params_list = []
for i in range(0, data_size, batch_size):
params = copy.deepcopy(request_params)
params[key] = data[i] if batch_size == 1 else data[i : i + batch_size]
request_params_list.append(params)
return request_params_list
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
post_str = self.get_generation_url()
try:
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error(
f"{self.__class__.__name__} request timed out."
" Increase client_timeout."
)
raise e
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json())
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
post_str = self.get_generation_url()
try:
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
async with session.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.format_response(res_json)
except aiohttp.ClientError as e:
logger.error(f"{self.__class__.__name__} request error {e}")
raise e
def run_request(self, request: Request) -> Response:
"""
Run the request and return the response.
@ -133,44 +247,80 @@ class Client(ABC):
request: request.
Returns:
request function that takes no input.
request parameters as dict.
response.
"""
if isinstance(request.prompt, list) and not self.supports_batch_inference():
raise ValueError(
f"{self.__class__.__name__} does not support batch inference."
)
request_params = request.to_dict(self.PARAMS)
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
return Response(
response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
)
def _run_completion() -> Dict:
post_str = self.get_generation_url()
try:
res = requests.post(
post_str,
headers=self.get_generation_header(),
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error(
f"{self.__class__.__name__} request timed out."
" Increase client_timeout."
)
raise e
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json())
return _run_completion, request_params
async def arun_batch_request(self, request: Request) -> Response:
"""
Run the async batch request and return the response.
Args:
request: request.
Returns:
response.
"""
required_batch_size = None
if not self.supports_batch_inference():
required_batch_size = 1
if not isinstance(request.prompt, list):
raise AssertionError(
"request.prompt must be a list for async batch inference."
)
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
batch_size = request_params.pop("batch_size")
batch_size = required_batch_size or batch_size
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
num_batches = len(request.prompt) // batch_size
if len(request.prompt) % batch_size != 0:
batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1)))
request_batches = self.split_requests(request_params, batch_size)
all_tasks = [
asyncio.create_task(self._arun_completion(batch, retry_timeout, batch_size))
for batch in request_batches
]
responses = await asyncio.gather(*all_tasks)
# Flatten responses
choices = []
for res_dict in responses:
choices.extend(res_dict["choices"])
final_response_dict = self.format_response({"choices": choices})
return Response(
final_response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
)
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.

@ -26,9 +26,9 @@ class CohereClient(Client):
"frequency_penalty": ("frequency_penalty", 0.0),
"presence_penalty": ("presence_penalty", 0.0),
"stop_sequences": ("stop_sequences", None),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
NAME = "cohere"
def connect(
self,

@ -22,9 +22,9 @@ class DiffuserClient(Client):
"n": ("num_images_per_prompt", 1),
"guidance_scale": ("guidance_scale", 7.5),
"eta": ("eta", 0.0),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest
NAME = "diffuser"
def connect(
self,

@ -1,9 +1,10 @@
"""Dummy client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
from manifest.response import Response
logger = logging.getLogger(__name__)
@ -16,6 +17,7 @@ class DummyClient(Client):
"n": ("num_results", 1),
}
REQUEST_CLS = LMRequest
NAME = "dummy"
def connect(
self,
@ -67,7 +69,7 @@ class DummyClient(Client):
"""
return {"engine": "dummy"}
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
def run_request(self, request: Request) -> Response:
"""
Run the request and return the response.
@ -84,19 +86,29 @@ class DummyClient(Client):
num_results = 1
request_params = request.to_dict(self.PARAMS)
def _run_completion() -> Dict:
return {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results
}
response_dict = {
"choices": [{"text": "hello"}]
* int(request_params["num_results"])
* num_results
}
return Response(response_dict, False, request_params)
return _run_completion, request_params
async def arun_batch_request(self, request: Request) -> Response:
"""
Run the async batch request and return the response.
Args:
request: request.
Returns:
response.
"""
return self.run_request(request)
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -113,17 +125,15 @@ class DummyClient(Client):
num_results = 1
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
return {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
}
for i in range(num_results)
]
}
return _run_completion, request_params
response_dict = {
"choices": [
{
"text": request.prompt
if isinstance(request.prompt, str)
else request.prompt[i],
"logprob": 0.3,
}
for i in range(num_results)
]
}
return Response(response_dict, False, request_params)
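To make the new request/cache-key helpers from the base client concrete, a sketch against the dummy client above (the expected values mirror the updated client tests at the end of this diff):

```python
from manifest.clients.dummy import DummyClient

client = DummyClient(connection_str=None, client_args={"n": 3})

# get_request builds the Request object; get_cache_key drops the
# NOT_CACHE_KEYS (client_timeout, batch_size) and mixes in the model
# params so cache keys stay unique per model.
request = client.get_request("hello", {})
print(client.get_cache_key(request))
# {'prompt': 'hello', 'num_results': 3, 'engine': 'dummy'}

response = client.run_request(request)
print(response.get_json_response())
# {'choices': [{'text': 'hello'}, {'text': 'hello'}, {'text': 'hello'}]}
```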

@ -1,11 +1,12 @@
"""Hugging Face client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import requests
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, Request
from manifest.response import Response
logger = logging.getLogger(__name__)
@ -22,9 +23,9 @@ class HuggingFaceClient(Client):
"top_k": ("top_k", 50),
"repetition_penalty": ("repetition_penalty", 1.0),
"do_sample": ("do_sample", True),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
NAME = "huggingface"
def connect(
self,
@ -81,7 +82,7 @@ class HuggingFaceClient(Client):
def get_score_prompt_request(
self,
request: Request,
) -> Tuple[Callable[[], Dict], Dict]:
) -> Response:
"""
Get the logit score of the prompt via a forward pass of the model.
@ -92,26 +93,26 @@ class HuggingFaceClient(Client):
response from the forward pass.
"""
request_params = request.to_dict(self.PARAMS)
request_params = self.get_request_params(request)
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
# Do not add params like we do with request as the model isn't sampling
request_params = {"prompt": request.prompt}
def _run_completion() -> Dict:
post_str = self.host + "/score_sequence"
try:
res = requests.post(
post_str,
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error("HF request timed out. Increase client_timeout.")
raise e
except requests.exceptions.HTTPError as e:
logger.error(res.text)
raise e
return res.json()
return _run_completion, request_params
post_str = self.host + "/score_sequence"
try:
res = requests.post(
post_str,
json=request_params,
timeout=retry_timeout,
)
res.raise_for_status()
except requests.Timeout as e:
logger.error("HF request timed out. Increase client_timeout.")
raise e
except requests.exceptions.HTTPError as e:
logger.error(res.text)
raise e
response_dict = res.json()
return Response(response_dict, cached=False, request_params=request_params)

@ -38,9 +38,9 @@ class OpenAIClient(Client):
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
NAME = "openai"
def connect(
self,

@ -0,0 +1,171 @@
"""OpenAIChat client."""
import copy
import logging
import os
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
OPENAICHAT_ENGINES = {
"gpt-3.5-turbo",
}
class OpenAIChatClient(Client):
"""OpenAI Chat client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "gpt-3.5-turbo"),
"temperature": ("temperature", 1.0),
"max_tokens": ("max_tokens", 10),
"n": ("n", 1),
"top_p": ("top_p", 1.0),
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
}
REQUEST_CLS = LMRequest
NAME = "openaichat"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OpenAI server.
connection_str is passed as default OPENAI_API_KEY if variable not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
"variable or pass through `client_connection`."
)
self.host = "https://api.openai.com/v1"
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. "
f"Must be {OPENAICHAT_ENGINES}."
)
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/chat/completions"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"Authorization": f"Bearer {self.api_key}"}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "openaichat", "engine": getattr(self, "engine")}
def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for chat.
Args:
request_params: request params.
Returns:
formatted request params.
"""
# Format for chat model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
messages = [{"role": "user", "content": prompt} for prompt in prompt_list]
request_params["messages"] = messages
return request_params
def _format_request_for_text(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response for text.
Args:
response_dict: response.
Return:
formatted response.
"""
new_choices = []
response_dict = copy.deepcopy(response_dict)
for message in response_dict["choices"]:
new_choices.append({"text": message["message"]["content"]})
response_dict["choices"] = new_choices
return response_dict
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_for_text(response_dict)
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
# Reformat for text model
response_dict = self._format_request_for_text(response_dict)
return response_dict
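A sketch of the prompt/response translation the chat client performs around the base completion path (the dicts are illustrative and only use fields shown in the methods above):

```python
# Outgoing: the flat "prompt" becomes a chat "messages" list.
request_params = {
    "model": "gpt-3.5-turbo",
    "prompt": "What is 2 + 2?",
    "max_tokens": 10,
}
# After _format_request_for_chat, roughly:
# {"model": "gpt-3.5-turbo", "max_tokens": 10,
#  "messages": [{"role": "user", "content": "What is 2 + 2?"}]}

# Incoming: chat choices are flattened back to the text shape the rest
# of Manifest expects.
chat_response = {"choices": [{"message": {"role": "assistant", "content": "4"}}]}
# After _format_request_for_text, roughly:
# {"choices": [{"text": "4"}]}
```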

@ -31,9 +31,9 @@ class TOMAClient(Client):
"top_p": ("top_p", 0.9),
"top_k": ("top_k", 40),
"stop_sequences": ("stop", []),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
NAME = "toma"
def connect(
self,

@ -30,9 +30,9 @@ class TOMADiffuserClient(TOMAClient):
"width": ("width", 512),
"n": ("n", 1),
"guidance_scale": ("guidance_scale", 7.5),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest # type: ignore
NAME = "tomadiffuser"
def get_model_params(self) -> Dict:
"""

@ -1,4 +1,5 @@
"""Manifest class."""
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@ -13,40 +14,32 @@ from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openaichat import OpenAIChatClient
from manifest.clients.toma import TOMAClient
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,
"cohere": CohereClient,
"ai21": AI21Client,
"huggingface": HuggingFaceClient,
"dummy": DummyClient,
"toma": TOMAClient,
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
HuggingFaceClient.NAME: HuggingFaceClient,
DummyClient.NAME: DummyClient,
TOMAClient.NAME: TOMAClient,
}
# ChatGPT
try:
from manifest.clients.chatgpt import ChatGPTClient
CLIENT_CONSTRUCTORS["chatgpt"] = ChatGPTClient
except Exception:
logger.info("ChatGPT not installed. Skipping import.")
pass
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
CLIENT_CONSTRUCTORS["diffuser"] = DiffuserClient
CLIENT_CONSTRUCTORS["tomadiffuser"] = TOMADiffuserClient
CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
logger.info("Diffusion not supported. Skipping import.")
pass
@ -69,7 +62,6 @@ class Manifest:
client_connection: Optional[str] = None,
cache_name: str = "noop",
cache_connection: Optional[str] = None,
session_id: Optional[str] = None,
stop_token: str = "",
**kwargs: Any,
):
@ -81,9 +73,6 @@ class Manifest:
client_connection: connection string for client.
cache_name: name of cache.
cache_connection: connection string for cache.
session_id: session id for user session cache.
None (default) means no session logging.
"_default" means generate new session id.
stop_token: stop token prompt generation.
Can be overridden in run
@ -114,17 +103,6 @@ class Manifest:
self.client = CLIENT_CONSTRUCTORS[self.client_name]( # type: ignore
client_connection, client_args=kwargs
)
if session_id is not None:
if self.client_name == "diffuser":
raise NotImplementedError(
"Session logging not implemented for Diffuser client."
)
if session_id == "_default":
# Set session_id to None for Session random id
session_id = None
self.session = Session(session_id)
else:
self.session = None
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@ -195,11 +173,133 @@ class Manifest:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
return
def _split_cached_requests(
self,
request: Request,
overwrite_cache: bool,
) -> Tuple[Dict[int, Response], Request]:
"""Split a request into cached responses and Requests to run.
Args:
request: request object.
overwrite_cache: whether to overwrite cache.
Returns:
cached_idx_to_response: dict of cached responses.
new_request: request object with only prompts to run.
"""
cached_idx_to_response: Dict[int, Response] = {}
new_request = copy.deepcopy(request)
if not overwrite_cache:
if isinstance(new_request.prompt, list):
new_request.prompt = []
for idx, prompt_str in enumerate(request.prompt):
single_request = copy.deepcopy(request)
single_request.prompt = prompt_str
possible_response = self.cache.get(
self.client.get_cache_key(single_request)
)
if possible_response:
cached_idx_to_response[idx] = possible_response
else:
new_request.prompt.append(prompt_str)
else:
possible_response = self.cache.get(
self.client.get_cache_key(new_request)
)
if possible_response:
cached_idx_to_response[0] = possible_response
new_request.prompt = None
return cached_idx_to_response, new_request
def _stitch_responses_and_cache(
self,
request: Request,
response: Union[Response, None],
cached_idx_to_response: Dict[int, Response],
) -> Response:
"""Stitch together the cached and uncached responses."""
# We stitch the responses (the choices) here from both the new request and
# the cached entries.
all_model_choices = []
all_input_prompts = []
response_idx = 0
number_prompts = len(cached_idx_to_response)
single_output = False
if response:
if isinstance(response.get_request()["prompt"], str):
single_output = True
number_prompts += 1
else:
number_prompts += len(response.get_request()["prompt"])
response_gen_key = None
response_logits_key = None
response_item_key = None
for idx in range(number_prompts):
if idx in cached_idx_to_response:
cached_res = cached_idx_to_response[idx]
response_gen_key = cached_res.generation_key
response_logits_key = cached_res.logits_key
response_item_key = cached_res.item_key
all_input_prompts.append(cached_res.get_request()["prompt"])
if request.n == 1:
assert (
len(cached_res.get_json_response()[response_gen_key]) == 1
), "cached response should have only one choice"
all_model_choices.append(
cached_res.get_json_response()[response_gen_key][0]
)
else:
all_model_choices.extend(
cached_res.get_json_response()[response_gen_key]
)
else:
assert response is not None, "response should not be None"
response = cast(Response, response)
response_gen_key = response.generation_key
response_logits_key = response.logits_key
response_item_key = response.item_key
# the choices list in the response is a flat one.
# length is request.n * num_prompts
current_choices = response.get_json_response()[response_gen_key][
response_idx * request.n : (response_idx + 1) * request.n
]
all_model_choices.extend(current_choices)
if isinstance(response.get_request()["prompt"], list):
prompt = response.get_request()["prompt"][response_idx]
else:
prompt = str(response.get_request()["prompt"])
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt
cache_key = self.client.get_cache_key(new_request)
new_response_key = copy.deepcopy(response.get_json_response())
new_response_key[response_gen_key] = current_choices
self.cache.set(cache_key, new_response_key)
response_idx += 1
new_request = copy.deepcopy(request)
new_request.prompt = (
all_input_prompts
if len(all_input_prompts) > 1 or not single_output
else all_input_prompts[0]
)
response_obj = Response(
{response_gen_key: all_model_choices},
cached=len(cached_idx_to_response) > 0,
request_params=self.client.get_cache_key(new_request),
generation_key=response_gen_key,
logits_key=response_logits_key,
item_key=response_item_key,
)
return response_obj
def run(
self,
prompt: Union[str, List[str]],
overwrite_cache: bool = False,
run_id: Optional[str] = None,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
@ -210,7 +310,6 @@ class Manifest:
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
run_id: run id for cache to repeat same run.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
@ -223,28 +322,88 @@ class Manifest:
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as dict so client "pop" methods remove used arguments
request_params = self.client.get_request_params(prompt, kwargs)
request_params = self.client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if is_batch and request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
possible_request, full_kwargs = self.client.get_request(request_params)
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If there are uncached prompts (not None or an empty list), run a new request
if request_params.prompt:
response = self.client.run_request(request_params)
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return final_response.get_response(stop_token, is_batch)
async def arun_batch(
self,
prompts: List[str],
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
"""
Run a batch of prompts with async.
Args:
prompts: prompts to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as dict so client "pop" methods remove used arguments
request_params = self.client.get_request(prompts, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
# Create cacke key
cache_key = full_kwargs.copy()
# Make query model dependent
cache_key.update(self.client.get_model_params())
if run_id:
cache_key["run_id"] = run_id
response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
# Log session dictionary values
if self.session:
self.session.log_query(cache_key, response_obj.to_dict())
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If there are uncached prompts (not None or an empty list), run a new request
if request_params.prompt:
response = await self.client.arun_batch_request(request_params)
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return response_obj
return final_response
else:
return response_obj.get_response(stop_token, is_batch)
return cast(
Union[List[str], List[np.ndarray]],
final_response.get_response(stop_token, True),
)
def score_prompt(
self,
@ -257,8 +416,6 @@ class Manifest:
Returns the response object with logits of the prompt.
Prompt scoring is not part of a session cache.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
@ -267,66 +424,31 @@ class Manifest:
response from prompt.
"""
# Must pass kwargs as dict so client "pop" methods remove used arguments
request_params = self.client.get_request_params(prompt, kwargs)
request_params = self.client.get_request(prompt, kwargs)
request_params.request_type = "score_prompt"
if request_params.n > 1:
raise ValueError("Sequence scoring does not support n > 1.")
try:
possible_request, full_kwargs = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
self._validate_kwargs(kwargs, request_params)
# Create cacke key
cache_key = full_kwargs.copy()
# Make query model dependent
cache_key.update(self.client.get_model_params())
response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
return response_obj.to_dict()
def get_last_queries(
self,
last_n: int = -1,
return_raw_values: bool = False,
stop_token: Optional[str] = None,
) -> List[Tuple[Any, Any]]:
"""
Get last n queries from current session.
If last_n is -1, return all queries. By default will only return the
prompt text and result text unles return_raw_values is False.
Args:
last_n: last n queries.
return_raw_values: whether to return raw values as dicts.
stop_token: stop token for prompt results to be applied to all results.
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
)
# If there are uncached prompts (not None or an empty list), run a new request
if request_params.prompt:
try:
response = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
else:
# Nothing to run
response = None
Returns:
last n list of queries and outputs.
"""
if self.session is None:
raise ValueError(
"Session was not initialized. Set `session_id` when loading Manifest."
)
stop_token = stop_token if stop_token is not None else self.stop_token
last_queries = self.session.get_last_queries(last_n)
if not return_raw_values:
last_queries = [
(
query["prompt"],
Response.from_dict(response).get_response(
stop_token, is_batch=isinstance(query["prompt"], list)
),
) # type: ignore
for query, response in last_queries
]
return last_queries
def open_explorer(self) -> None:
"""Open the explorer for jupyter widget."""
# Open explorer
# TODO: implement
pass
final_response = self._stitch_responses_and_cache(
request=request_params,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
return final_response.to_dict()
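A behavioral sketch of the per-item batch caching fix (the dummy client and SQLite path are illustrative):

```python
from manifest import Manifest

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="mycache.sqlite",  # illustrative path
)

# A batch run caches every prompt individually.
_ = manifest.run(["a prompt", "b prompt"])

# A later, partially overlapping batch only sends the uncached prompts;
# "a prompt" is read from the cache and stitched back into the result.
result = manifest.run(["a prompt", "c prompt"])
print(result)
```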

@ -3,6 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
DEFAULT_REQUEST_KEYS = {
"client_timeout": ("client_timeout", 60), # seconds
"batch_size": ("batch_size", 1),
"run_id": ("run_id", None),
"request_type": ("request_type", None),
}
class Request(BaseModel):
"""Request object."""
@ -17,7 +25,16 @@ class Request(BaseModel):
n: int = 1
# Timeout
client_timeout: int = 60
client_timeout: int = 120
# Run id used to repeat run with same parameters
run_id: Optional[str] = None
# Batch size for async batch run
batch_size: int = 8
# Request type. None is for completion; used for prompt scoring
request_type: str = None
def to_dict(
self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
@ -82,6 +99,9 @@ class LMRequest(Request):
class DiffusionRequest(Request):
"""Diffusion Model Request object."""
# Request type
request_type: str = "diffusion"
# Number of steps
num_inference_steps: int = 50
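To illustrate how the new request defaults flow (field names come from the classes above; the values are just the defaults):

```python
from manifest.request import NOT_CACHE_KEYS, LMRequest

req = LMRequest(prompt="hello")
params = req.dict()  # plain pydantic field dump

# client_timeout and batch_size travel with every request so the client
# can pop them off before hitting the model...
assert params["client_timeout"] == 120
assert params["batch_size"] == 8

# ...but NOT_CACHE_KEYS keeps them out of the cache key.
cache_key = {k: v for k, v in params.items() if k not in NOT_CACHE_KEYS}
assert "client_timeout" not in cache_key and "batch_size" not in cache_key
```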

@ -4,6 +4,17 @@ from typing import Any, Dict, List, Union
import numpy as np
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
}
class NumpyArrayEncoder(json.JSONEncoder):
"""Numpy array encoder."""

@ -1,156 +0,0 @@
"""User query session logging."""
import logging
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from manifest.caches.serializers import Serializer
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
class Session:
"""A user session for caching requests."""
def __init__(self, session_id: Optional[str] = None) -> None:
"""
Initialize session.
If session_id already exists, will append to existing session.
Args:
session_id: session id.
"""
manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
self.db_file = manifest_home / ".manifest" / "session.db"
self.db_file.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(str(self.db_file))
self.serializer = Serializer()
self._create_table()
if not session_id:
self.session_id = str(uuid.uuid4())
self.query_id = 0
else:
self.session_id = session_id
self.query_id = self._get_latest_query_id(self.session_id)
self.query_id += 1
logger.info(f"Starting session {self.session_id}")
return
def close(self) -> None:
"""Close the client."""
self.conn.close()
@classmethod
def get_session_keys(cls, db_file: Path) -> List[str]:
"""Get available session keys from cached file."""
try:
conn = sqlite3.connect(str(db_file))
query = """SELECT DISTINCT session_id FROM queries"""
cur = conn.cursor()
res = cur.execute(query)
return [x[0] for x in res.fetchall()]
except sqlite3.OperationalError:
logger.info(
"There is no database with the 'queries' table. "
"Are you sure you are using the right session file"
)
return []
def _execute_query(self, query: str, *args: Any) -> Any:
"""
Execute query with optional args.
Args:
query: query to execute.
"""
cur = self.conn.cursor()
res = cur.execute(query, args)
self.conn.commit()
return res
def _create_table(self) -> None:
"""Create table if not exists."""
query = """CREATE TABLE IF NOT EXISTS queries (
query_id integer NOT NULL,
session_id text NOT NULL,
query_key text NOT NULL,
response_key text NOT NULL
);"""
self._execute_query(query)
return
def _get_latest_query_id(self, session_id: str) -> int:
"""
Get latest query id issued if resuming session.
If no session_id, return -1.
Args:
session_id: session id.
Returns:
latest query id.
"""
query = """SELECT query_id
FROM queries
WHERE session_id = ?
ORDER BY query_id DESC LIMIT 1"""
res = self._execute_query(query, session_id).fetchone()
if res:
return res[0]
return -1
def log_query(
self, query_key: Dict[str, Any], response_key: Dict[str, Any]
) -> None:
"""
Log the query and response.
Args:
query_key: query of user (dump of request params).
response_key: response of server (dump of response).
"""
query = """INSERT INTO queries VALUES (?, ?, ?, ?);"""
self._execute_query(
query,
self.query_id,
self.session_id,
self.serializer.request_to_key(query_key),
self.serializer.response_to_key(response_key),
)
self.query_id += 1
return
def get_last_queries(
self, last_n: int = -1
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
"""
Get last n queries from current session.
If last_n is -1, return all queries.
Args:
last_n: last n queries.
Returns:
last n list of queries and outputs.
"""
first_query = self.query_id - last_n if last_n > 0 else -1
query = """SELECT query_key, response_key
FROM queries
WHERE session_id = ? AND query_id >= ?
ORDER BY query_id;"""
res = self._execute_query(query, self.session_id, first_query)
parsed_res = [
(
self.serializer.key_to_request(pair[0]),
self.serializer.key_to_response(pair[1]),
)
for pair in res.fetchall()
]
return parsed_res

@ -32,6 +32,7 @@ REQUIRED = [
"pydantic>=1.9.0",
"redis>=4.3.1",
"requests>=2.27.1",
"aiohttp>=3.8.0",
"sqlitedict>=2.0.0",
"xxhash>=3.0.0",
]
@ -50,9 +51,6 @@ EXTRAS = {
"fastapi>=0.70.0",
"uvicorn>=0.18.0",
],
"chatgpt": [
"pyChatGPT>=0.4.3",
],
"diffusers": [
"pillow>=9.0.0",
],
@ -60,7 +58,7 @@ EXTRAS = {
"pg8000",
"cloud-sql-python-connector[pg8000]>=1.0.0",
"sqlalchemy",
],
],
"dev": [
"autopep8>=1.6.0",
"black>=22.3.0",

@ -42,10 +42,3 @@ def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None
engine = sqlalchemy.create_engine(url)
monkeypatch.setattr(sqlalchemy, "create_engine", lambda *args, **kwargs: engine)
return engine # type: ignore
@pytest.fixture
def session_cache(tmpdir: str) -> Generator[Path, None, None]:
"""Session cache dir."""
os.environ["MANIFEST_HOME"] = str(tmpdir)
yield Path(tmpdir)

@ -85,26 +85,20 @@ def test_get(
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
test_response = {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request)
assert response is None
response = cache.get(test_request, overwrite_cache=False, compute=compute)
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == "hello"
assert response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=True, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute_arr = lambda: {"choices": [{"array": arr}]}
compute_arr_response = {"choices": [{"array": arr}]}
# Test array
if cache_type == "sqlite":
@ -114,9 +108,64 @@ def test_get(
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
response = cache.get(test_request, overwrite_cache=False, compute=compute_arr)
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response(), arr)
assert not response.is_cached()
assert response.is_cached()
assert response.get_request() == test_request
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_get_batch_prompt(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
cache = cast(Cache, SQLiteCache(sqlite_cache))
elif cache_type == "redis":
cache = cast(Cache, RedisCache(redis_cache))
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": ["hello", "goodbye"], "testA": "world"}
test_response = {"choices": [{"text": "hello"}, {"text": "goodbye"}]}
response = cache.get(test_request)
assert response is None
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response.get_response() == ["hello", "goodbye"]
assert response.is_cached()
assert response.get_request() == test_request
# Test arrays
arr = np.random.rand(4, 4)
arr2 = np.random.rand(4, 4)
test_request = {"test": ["hello", "goodbye"], "testA": "world of images"}
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
elif cache_type == "redis":
cache = RedisCache(redis_cache, client_name="diffuser")
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
response = cache.get(test_request)
assert response is None
cache.set(test_request, compute_arr_response)
response = cache.get(test_request)
assert np.allclose(response.get_response()[0], arr)
assert np.allclose(response.get_response()[1], arr2)
assert response.is_cached()
assert response.get_request() == test_request
@ -137,14 +186,11 @@ def test_noop_cache() -> None:
# Assert always not cached
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
test_response = {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request)
assert response is None
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
cache.set(test_request, test_response)
response = cache.get(test_request)
assert response is None

@ -27,17 +27,29 @@ def test_get_request() -> None:
"""Test client get request."""
args = {"n": 3}
client = DummyClient(connection_str=None, client_args=args)
request_params = client.get_request_params("hello", {})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}
request_params = client.get_request_params("hello", {"n": 5})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 5}
assert request_func() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request_params(["hello"] * 5, {"n": 1})
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": ["hello"] * 5, "num_results": 1}
assert request_func() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request("hello", {})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 3,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 3}
request_params = client.get_request("hello", {"n": 5})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": "hello",
"num_results": 5,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request(["hello"] * 5, {"n": 1})
response = client.run_request(request_params)
assert client.get_cache_key(request_params) == {
"prompt": ["hello"] * 5,
"num_results": 1,
"engine": "dummy",
}
assert response.get_json_response() == {"choices": [{"text": "hello"}] * 5}

@ -141,8 +141,8 @@ def test_gpt_score() -> None:
result = model.score_sequence(inputs)
assert result is not None
assert len(result) == 2
assert math.isclose(round(result[0][0], 3), -19.935)
assert math.isclose(round(result[1][0], 3), -45.831)
assert math.isclose(round(result[0][0], 3), -46.71)
assert math.isclose(round(result[1][0], 3), -12.752)
assert isinstance(result[0][1], list)
assert isinstance(result[1][1], list)

@ -1,19 +1,25 @@
"""Manifest test."""
import json
import asyncio
from typing import cast
import pytest
import requests
from manifest import Manifest, Response
from manifest.caches.noop import NoopCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.session import Session
URL = "http://localhost:6000"
try:
_ = requests.post(URL + "/params").json()
MODEL_ALIVE = True
except Exception:
MODEL_ALIVE = False
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_init(sqlite_cache: str, session_cache: str) -> None:
def test_init(sqlite_cache: str) -> None:
"""Test manifest initialization."""
with pytest.raises(ValueError) as exc_info:
Manifest(
@ -32,7 +38,6 @@ def test_init(sqlite_cache: str, session_cache: str) -> None:
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == ""
@ -41,19 +46,16 @@ def test_init(sqlite_cache: str, session_cache: str) -> None:
cache_name="noop",
n=3,
stop_token="\n",
session_id="_default",
)
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, NoopCache)
assert isinstance(manifest.session, Session)
assert manifest.client.n == 3 # type: ignore
assert manifest.stop_token == "\n"
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
def test_change_manifest(sqlite_cache: str) -> None:
"""Test manifest change."""
manifest = Manifest(
client_name="dummy",
@ -65,7 +67,6 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == ""
@ -73,18 +74,14 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == "\n"
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_run(
sqlite_cache: str, session_cache: str, n: int, return_response: bool
) -> None:
def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@ -111,15 +108,12 @@ def test_run(
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
},
)
is not None
)
@ -136,16 +130,13 @@ def test_run(
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
"run_id": "34",
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
"run_id": "34",
}
)
is not None
)
@ -162,15 +153,12 @@ def test_run(
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
},
)
is not None
)
@ -187,15 +175,12 @@ def test_run(
else:
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
},
)
is not None
)
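The recurring change in `test_run` is that cache lookups no longer go through `manifest.cache.get_key(json.dumps(..., sort_keys=True))`; the cache is queried directly with the request dict. A sketch of the lookup after a run (the SQLite path below is a placeholder, not a value from this diff):

```python
from manifest import Manifest

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="manifest.sqlite",  # placeholder path
)
manifest.run("This is a prompt")

# The request dict itself is the cache key; no json.dumps/sort_keys needed.
assert manifest.cache.get(
    {"prompt": "This is a prompt", "engine": "dummy", "num_results": 1}
) is not None
```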
@ -206,12 +191,9 @@ def test_run(
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_batch_run(
sqlite_cache: str, session_cache: str, n: int, return_response: bool
) -> None:
def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@ -233,6 +215,16 @@ def test_batch_run(
else:
res = cast(str, result)
assert res == ["hello"]
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
},
)
is not None
)
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, return_response=return_response)
@ -243,6 +235,42 @@ def test_batch_run(
else:
res = cast(str, result)
assert res == ["hello", "hello"]
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
},
)
is not None
)
result = manifest.run(prompt, return_response=True)
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert cast(Response, result).is_cached()
assert (
manifest.cache.get(
{
"prompt": "New prompt",
"engine": "dummy",
"num_results": n,
},
)
is None
)
prompt = ["This is a prompt", "New prompt"]
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
)
# Cached because one item is in cache
assert cast(Response, result).is_cached()
else:
res = cast(str, result)
assert res == ["hello", "hello"]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
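These additions exercise the batch-cache fix: a batched `run` now writes one cache entry per prompt, so a later batch that mixes a cached prompt with a new one is reported as cached while only the miss is recomputed. A sketch of that per-item behavior (placeholder SQLite path again):

```python
from typing import cast
from manifest import Manifest, Response

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="manifest.sqlite",  # placeholder path
)

# A batched run caches each prompt individually.
manifest.run(["This is a prompt", "Hello is a prompt"])
assert manifest.cache.get(
    {"prompt": "Hello is a prompt", "engine": "dummy", "num_results": 1}
) is not None

# Mixing a cached prompt with a new one: the response is marked cached
# because at least one item came from the cache.
result = cast(
    Response,
    manifest.run(["This is a prompt", "New prompt"], return_response=True),
)
assert result.is_cached()
```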
@ -253,6 +281,72 @@ def test_batch_run(
assert res == ["he", "he"]
@pytest.mark.usefixtures("sqlite_cache")
def test_abatch_run(sqlite_cache: str) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
prompt = ["This is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert res == ["hello"]
assert (
manifest.cache.get(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": 1,
},
)
is not None
)
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert res == ["hello", "hello"]
assert (
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": 1,
},
)
is not None
)
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
assert cast(Response, result).is_cached()
assert (
manifest.cache.get(
{
"prompt": "New prompt",
"engine": "dummy",
"num_results": 1,
},
)
is None
)
prompt = ["This is a prompt", "New prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache
assert cast(Response, result).is_cached()
assert res == ["hello", "hello"]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
assert res == ["he", "he"]
@pytest.mark.usefixtures("sqlite_cache")
def test_score_run(sqlite_cache: str) -> None:
"""Test manifest run."""
@ -264,16 +358,14 @@ def test_score_run(sqlite_cache: str) -> None:
prompt = "This is a prompt"
result = manifest.score_prompt(prompt)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": "This is a prompt",
"engine": "dummy",
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
)
@ -284,20 +376,35 @@ def test_score_run(sqlite_cache: str) -> None:
"item_dtype": None,
"response": {"choices": [{"text": "This is a prompt", "logprob": 0.3}]},
"cached": False,
"request_params": {"prompt": "This is a prompt", "engine": "dummy"},
"request_params": {
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
},
}
prompt_list = ["Hello is a prompt", "Hello is another prompt"]
result = manifest.score_prompt(prompt_list)
assert (
manifest.cache.get_key(
json.dumps(
{
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "dummy",
},
sort_keys=True,
)
manifest.cache.get(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
)
assert (
manifest.cache.get(
{
"prompt": "Hello is another prompt",
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
},
)
is not None
)
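`score_prompt` follows the same per-item scheme: a batched scoring call now writes one entry per prompt, keyed with `request_type: "score_prompt"` and `num_results`, instead of a single entry keyed on the whole prompt list. Sketch:

```python
from manifest import Manifest

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="manifest.sqlite",  # placeholder path
)
manifest.score_prompt(["Hello is a prompt", "Hello is another prompt"])

# Each scored prompt gets its own cache entry, tagged with the scoring request type.
for prompt in ["Hello is a prompt", "Hello is another prompt"]:
    assert manifest.cache.get(
        {
            "prompt": prompt,
            "engine": "dummy",
            "num_results": 1,
            "request_type": "score_prompt",
        }
    ) is not None
```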
@ -316,76 +423,64 @@ def test_score_run(sqlite_cache: str) -> None:
"request_params": {
"prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "dummy",
"num_results": 1,
"request_type": "score_prompt",
},
}
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache: str) -> None:
"""Test manifest session logging."""
manifest = Manifest(client_name="dummy", cache_name="noop", session_id="_default")
prompt = "This is a prompt"
_ = manifest.run(prompt, return_response=False)
query_key = {
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": 1,
}
response_key = {
"cached": False,
"request_params": query_key,
"response": {"choices": [{"text": "hello"}]},
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "token_logprobs",
}
assert manifest.get_last_queries(1) == [("This is a prompt", "hello")]
assert manifest.get_last_queries(1, return_raw_values=True) == [
(query_key, response_key)
]
assert manifest.get_last_queries(3, return_raw_values=True) == [
(query_key, response_key)
]
prior_cache_item = (query_key, response_key)
prompt_lst = ["This is a prompt", "This is a prompt2"]
_ = manifest.run(prompt_lst, return_response=False)
query_key = {
"prompt": ["This is a prompt", "This is a prompt2"],
"engine": "dummy",
"num_results": 1,
}
response_key = {
"cached": False,
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "token_logprobs",
"request_params": query_key,
"response": {"choices": [{"text": "hello"}, {"text": "hello"}]},
}
assert manifest.get_last_queries(1) == [
(["This is a prompt", "This is a prompt2"], ["hello", "hello"])
]
assert manifest.get_last_queries(1, return_raw_values=True) == [
(query_key, response_key)
]
assert manifest.get_last_queries(3, return_raw_values=True) == [
prior_cache_item,
(query_key, response_key),
]
# Test no session
manifest = Manifest(
client_name="dummy",
cache_name="noop",
@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}")
@pytest.mark.usefixtures("sqlite_cache")
def test_local_huggingface(sqlite_cache: str) -> None:
"""Test local huggingface client."""
client = Manifest(
client_name="huggingface",
client_connection=URL,
cache_name="sqlite",
cache_connection=sqlite_cache,
)
prompt = "This is a prompt"
_ = manifest.run(prompt, return_response=False)
with pytest.raises(ValueError) as exc_info:
manifest.get_last_queries(1)
assert (
str(exc_info.value)
== "Session was not initialized. Set `session_id` when loading Manifest."
res = client.run("Why are there apples?")
assert isinstance(res, str) and len(res) > 0
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.is_cached() is True
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
res_list = client.run(["Why are there apples?", "Why are there bananas?"])
assert isinstance(res_list, list) and len(res_list) == 2
response = cast(
Response, client.run("Why are there bananas?", return_response=True)
)
assert response.is_cached() is True
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list, list) and len(res_list) == 2
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True
scores = client.score_prompt("Why are there apples?")
assert isinstance(scores, dict) and len(scores) > 0
assert scores["cached"] is False
assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
scores["response"]["choices"][0]["tokens"]
)
scores = client.score_prompt(["Why are there apples?", "Why are there bananas?"])
assert isinstance(scores, dict) and len(scores) > 0
assert scores["cached"] is True
assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
scores["response"]["choices"][0]["tokens"]
)
assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
scores["response"]["choices"][0]["tokens"]
)
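`test_local_huggingface` replaces the removed session-logging test and doubles as an integration check: it is skipped unless a model answers at `http://localhost:6000`. A sketch of the same liveness probe plus a single query, assuming a HuggingFace model is already being served at that address (the SQLite path is a placeholder):

```python
import requests
from manifest import Manifest

URL = "http://localhost:6000"

# Probe the local model server the same way the test module does.
try:
    requests.post(URL + "/params").json()
except Exception:
    raise SystemExit(f"No model is being served at {URL}")

client = Manifest(
    client_name="huggingface",
    client_connection=URL,
    cache_name="sqlite",
    cache_connection="manifest.sqlite",  # placeholder path
)
print(client.run("Why are there apples?"))
```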

@ -10,10 +10,10 @@ def test_llm_init() -> None:
request = LMRequest(temperature=0.5)
assert request.temperature == 0.5
request = LMRequest(**{"temperature": 0.5})
request = LMRequest(**{"temperature": 0.5}) # type: ignore
assert request.temperature == 0.5
request = LMRequest(**{"temperature": 0.5, "prompt": "test"})
request = LMRequest(**{"temperature": 0.5, "prompt": "test"}) # type: ignore
assert request.temperature == 0.5
assert request.prompt == "test"
@ -26,10 +26,10 @@ def test_diff_init() -> None:
request = DiffusionRequest(height=128)
assert request.height == 128
request = DiffusionRequest(**{"height": 128})
request = DiffusionRequest(**{"height": 128}) # type: ignore
assert request.height == 128
request = DiffusionRequest(**{"height": 128, "prompt": "test"})
request = DiffusionRequest(**{"height": 128, "prompt": "test"}) # type: ignore
assert request.height == 128
assert request.prompt == "test"
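The `# type: ignore` comments above are needed because the tests now unpack untyped dicts into the request constructors; direct keyword construction stays fully typed. A small sketch of both request models using only the fields exercised here (the `manifest.request` import path is an assumption, not shown in this diff):

```python
# Import path assumed; the diff only shows LMRequest/DiffusionRequest by name.
from manifest.request import DiffusionRequest, LMRequest

# Typed keyword construction needs no ignore comment.
lm_request = LMRequest(temperature=0.5, prompt="test")
assert lm_request.temperature == 0.5

diff_request = DiffusionRequest(height=128, prompt="test")
assert diff_request.height == 128
```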

@ -1,13 +1,12 @@
"""Cache test."""
import json
from pathlib import Path
import numpy as np
from manifest.caches.serializers import ArraySerializer
def test_response_to_key(session_cache: Path) -> None:
def test_response_to_key() -> None:
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)

@ -1,81 +0,0 @@
"""Test session."""
import sqlite3
from pathlib import Path
import pytest
from manifest.session import Session
@pytest.mark.usefixtures("session_cache")
def test_init(session_cache: Path) -> None:
"""Test session initialization."""
session = Session()
assert isinstance(session.conn, sqlite3.Connection)
assert session.db_file == session_cache / ".manifest" / "session.db"
assert session.query_id == 0
assert (session_cache / ".manifest" / "session.db").exists()
# Remove session cache file.
(session_cache / ".manifest" / "session.db").unlink()
session = Session("dog_days")
assert isinstance(session.conn, sqlite3.Connection)
assert session.db_file == session_cache / ".manifest" / "session.db"
assert session.query_id == 0
assert session.session_id == "dog_days"
assert (session_cache / ".manifest" / "session.db").exists()
session.close()
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session()
assert session.get_last_queries(1) == []
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
assert session.query_id == 1
assert session.get_last_queries(1) == [(query_key, response_key)]
query_key2 = {"query2": "What is your name?", "time": "now"}
response_key2 = {"response2": "I don't have a name", "engine": "nodel"}
session.log_query(query_key2, response_key2)
assert session.query_id == 2
assert len(session.get_last_queries(1)) == 1
assert session.get_last_queries(2) == [
(query_key, response_key),
(query_key2, response_key2),
]
session.close()
@pytest.mark.usefixtures("session_cache")
def test_resume_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session(session_id="dog_days")
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
session.close()
session = Session(session_id="dog_days")
assert session.query_id == 1
@pytest.mark.usefixtures("session_cache")
def test_session_keys(session_cache: Path) -> None:
"""Test get session keys."""
# Assert empty before queries
assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == []
# Add queries and make sure session is logged
session = Session(session_id="dog_days")
query_key = {"query": "What is your name?", "time": "now"}
response_key = {"response": "I don't have a name", "engine": "nodel"}
session.log_query(query_key, response_key)
session.close()
assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == [
"dog_days"
]