mirror of https://github.com/HazyResearch/manifest
feat: chatgpt client added (#47)
parent
defc63bf36
commit
56eae406ce
@ -0,0 +1,79 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manifest import Manifest\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# ChatGPT tries hard not to give people programmatic access.\n",
|
||||
"# As a warning, this will open a browser window.\n",
|
||||
"# You need to install xvfb and chromium for linux headless mode to work\n",
|
||||
"# See https://github.com/terry3041/pyChatGPT\n",
|
||||
"\n",
|
||||
"# The responses are not fast\n",
|
||||
"manifest = Manifest(\n",
|
||||
" client_name=\"chatgpt\",\n",
|
||||
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "bootleg",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,130 @@
|
||||
"""Client class."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from pyChatGPT import ChatGPT
|
||||
|
||||
from manifest.clients.client import Client
|
||||
from manifest.request import LMRequest, Request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGPTClient(Client):
|
||||
"""ChatGPT Client class."""
|
||||
|
||||
# No params for ChatGPT
|
||||
PARAMS = {}
|
||||
REQUEST_CLS = LMRequest
|
||||
|
||||
def connect(
|
||||
self, connection_str: Optional[str], client_args: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Connect to ChatGPT.
|
||||
|
||||
We use https://github.com/terry3041/pyChatGPT.
|
||||
|
||||
Arsg:
|
||||
connection_str: connection string.
|
||||
client_args: client arguments.
|
||||
"""
|
||||
self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
|
||||
if self.session_key is None:
|
||||
raise ValueError(
|
||||
"ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
|
||||
"variable or pass through `client_connection`. "
|
||||
"For details, see https://github.com/terry3041/pyChatGPT "
|
||||
"and go through instructions for getting a session key."
|
||||
)
|
||||
self.host = None
|
||||
for key in self.PARAMS:
|
||||
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
||||
self._chat_session = ChatGPT(self.session_key, verbose=False)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the client."""
|
||||
self._chat_session = None
|
||||
|
||||
def clear_conversations(self) -> None:
|
||||
"""Clear conversations.
|
||||
|
||||
Only works for ChatGPT.
|
||||
"""
|
||||
self._chat_session.clear_conversations()
|
||||
|
||||
def get_generation_url(self) -> str:
|
||||
"""Get generation URL."""
|
||||
return ""
|
||||
|
||||
def get_generation_header(self) -> Dict[str, str]:
|
||||
"""
|
||||
Get generation header.
|
||||
|
||||
Returns:
|
||||
header.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def supports_batch_inference(self) -> bool:
|
||||
"""Return whether the client supports batch inference."""
|
||||
return False
|
||||
|
||||
def get_model_params(self) -> Dict:
|
||||
"""
|
||||
Get model params.
|
||||
|
||||
By getting model params from the server, we can add to request
|
||||
and make sure cache keys are unique to model.
|
||||
|
||||
Returns:
|
||||
model params.
|
||||
"""
|
||||
return {"model_name": "chatgpt", "engine": "chatgpt"}
|
||||
|
||||
def format_response(self, response: Dict) -> Dict[str, Any]:
|
||||
"""
|
||||
Format response to dict.
|
||||
|
||||
Args:
|
||||
response: response
|
||||
|
||||
Return:
|
||||
response as dict
|
||||
"""
|
||||
return {
|
||||
"model": "chatgpt",
|
||||
"choices": [
|
||||
{
|
||||
"text": response["message"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
|
||||
"""
|
||||
Get request string function.
|
||||
|
||||
Args:
|
||||
request: request.
|
||||
|
||||
Returns:
|
||||
request function that takes no input.
|
||||
request parameters as dict.
|
||||
"""
|
||||
if isinstance(request.prompt, list):
|
||||
raise ValueError("ChatGPT does not support batch inference.")
|
||||
|
||||
prompt = str(request.prompt)
|
||||
request_params = request.to_dict(self.PARAMS)
|
||||
|
||||
def _run_completion() -> Dict:
|
||||
try:
|
||||
res = self._chat_session.send_message(prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"ChatGPT error {e}.")
|
||||
raise e
|
||||
return self.format_response(res)
|
||||
|
||||
return _run_completion, request_params
|
Loading…
Reference in New Issue