feat: chatgpt client added (#47)

laurel/helm
Laurel Orr 1 year ago committed by GitHub
parent defc63bf36
commit 56eae406ce

1
.gitignore vendored

@ -35,7 +35,6 @@ wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template

@ -4,6 +4,8 @@ Added
^^^^^
* Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method.
* Standard request base model for all language inputs.
* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
* Diffusion model support
Fixed
^^^^^^^^

@ -16,6 +16,12 @@ Install:
pip install manifest-ml
```
Install with ChatGPT Support:
```bash
pip install manifest-ml[chatgpt]
```
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
Install with HuggingFace API Support:
```bash
pip install manifest-ml[api]

@ -428,7 +428,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"display_name": "bootleg",
"language": "python",
"name": "python3"
},
@ -442,11 +442,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
"version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
},
"vscode": {
"interpreter": {
"hash": "0d67557b0e03f6eb64c46b70fb42ce7c6498a7305f9f3922c351822f3fc8e363"
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
}
}
},

@ -0,0 +1,79 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"import os\n",
"\n",
"# ChatGPT tries hard not to give people programmatic access.\n",
"# As a warning, this will open a browser window.\n",
"# You need to install xvfb and chromium for linux headless mode to work\n",
"# See https://github.com/terry3041/pyChatGPT\n",
"\n",
"# The responses are not fast\n",
"manifest = Manifest(\n",
" client_name=\"chatgpt\",\n",
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
]
}
],
"source": [
"print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "bootleg",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,130 @@
"""Client class."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from pyChatGPT import ChatGPT
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
class ChatGPTClient(Client):
"""ChatGPT Client class."""
# No params for ChatGPT
PARAMS = {}
REQUEST_CLS = LMRequest
def connect(
self, connection_str: Optional[str], client_args: Dict[str, Any]
) -> None:
"""
Connect to ChatGPT.
We use https://github.com/terry3041/pyChatGPT.
Arsg:
connection_str: connection string.
client_args: client arguments.
"""
self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
if self.session_key is None:
raise ValueError(
"ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
"variable or pass through `client_connection`. "
"For details, see https://github.com/terry3041/pyChatGPT "
"and go through instructions for getting a session key."
)
self.host = None
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
self._chat_session = ChatGPT(self.session_key, verbose=False)
def close(self) -> None:
"""Close the client."""
self._chat_session = None
def clear_conversations(self) -> None:
"""Clear conversations.
Only works for ChatGPT.
"""
self._chat_session.clear_conversations()
def get_generation_url(self) -> str:
"""Get generation URL."""
return ""
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "chatgpt", "engine": "chatgpt"}
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"model": "chatgpt",
"choices": [
{
"text": response["message"],
}
],
}
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
request: request.
Returns:
request function that takes no input.
request parameters as dict.
"""
if isinstance(request.prompt, list):
raise ValueError("ChatGPT does not support batch inference.")
prompt = str(request.prompt)
request_params = request.to_dict(self.PARAMS)
def _run_completion() -> Dict:
try:
res = self._chat_session.send_message(prompt)
except Exception as e:
logger.error(f"ChatGPT error {e}.")
raise e
return self.format_response(res)
return _run_completion, request_params

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
# Engines are dynamically instantiated from API
# but a few example engines are listed below.
TOMA_ENGINES = {
"StableDiffusion",
# "StableDiffusion",
"Together-gpt-JT-6B-v1",
}

@ -8,6 +8,7 @@ from manifest.caches.noop import NoopCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
from manifest.clients.chatgpt import ChatGPTClient
from manifest.clients.cohere import CohereClient
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.dummy import DummyClient
@ -22,6 +23,7 @@ logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,
"chatgpt": ChatGPTClient,
"cohere": CohereClient,
"ai21": AI21Client,
"huggingface": HuggingFaceClient,

@ -15,7 +15,8 @@ module = [
"accelerate.utils.modeling",
"transformers",
"flask",
"torch"
"torch",
"pyChatGPT",
]
[tool.isort]

@ -34,13 +34,18 @@ EXTRAS = {
"api": [
"diffusers>=0.6.0",
"Flask>=2.1.2",
"fastapi>=0.70.0",
"uvicorn>=0.18.0",
"accelerate>=0.10.0",
"transformers>=4.20.0",
"torch>=1.8.0",
"numpy>=1.20.0",
],
"app": [
"fastapi>=0.70.0",
"uvicorn>=0.18.0",
],
"chatgpt": [
"pyChatGPT>=0.4.3",
],
"dev": [
"autopep8>=1.6.0",
"black>=22.3.0",

Loading…
Cancel
Save