From 56eae406cec9cacc72dd400ce029265d68bfb3ed Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Sun, 8 Jan 2023 14:58:12 -0800 Subject: [PATCH] feat: chatgpt client added (#47) --- .gitignore | 1 - CHANGELOG.rst | 2 + README.md | 6 ++ examples/langchain_chatgpt.ipynb | 6 +- examples/manifest_chatgpt.ipynb | 79 +++++++++++++++++++ manifest/clients/chatgpt.py | 130 +++++++++++++++++++++++++++++++ manifest/clients/toma.py | 2 +- manifest/manifest.py | 2 + pyproject.toml | 3 +- setup.py | 9 ++- 10 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 examples/manifest_chatgpt.ipynb create mode 100644 manifest/clients/chatgpt.py diff --git a/.gitignore b/.gitignore index 58e8776..e0cd9fd 100644 --- a/.gitignore +++ b/.gitignore @@ -35,7 +35,6 @@ wheels/ *.egg-info/ .installed.cfg *.egg -MANIFEST # PyInstaller # Usually these files are written by a python script from a template diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 86eb517..1a70c22 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Added ^^^^^ * Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method. * Standard request base model for all language inputs. +* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in. +* Diffusion model support Fixed ^^^^^^^^ diff --git a/README.md b/README.md index 7d42d3c..d77720f 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ Install: pip install manifest-ml ``` +Install with ChatGPT Support: +```bash +pip install manifest-ml[chatgpt] +``` +This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`. 
+ Install with HuggingFace API Support: ```bash pip install manifest-ml[api] diff --git a/examples/langchain_chatgpt.ipynb b/examples/langchain_chatgpt.ipynb index d67bbcd..3b4b0c6 100644 --- a/examples/langchain_chatgpt.ipynb +++ b/examples/langchain_chatgpt.ipynb @@ -428,7 +428,7 @@ ], "metadata": { "kernelspec": { - "display_name": "manifest", + "display_name": "bootleg", "language": "python", "name": "python3" }, @@ -442,11 +442,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]" }, "vscode": { "interpreter": { - "hash": "0d67557b0e03f6eb64c46b70fb42ce7c6498a7305f9f3922c351822f3fc8e363" + "hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250" } } }, diff --git a/examples/manifest_chatgpt.ipynb b/examples/manifest_chatgpt.ipynb new file mode 100644 index 0000000..931ef00 --- /dev/null +++ b/examples/manifest_chatgpt.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from manifest import Manifest\n", + "import os\n", + "\n", + "# ChatGPT tries hard not to give people programmatic access.\n", + "# As a warning, this will open a browser window.\n", + "# You need to install xvfb and chromium for linux headless mode to work\n", + "# See https://github.com/terry3041/pyChatGPT\n", + "\n", + "# The responses are not fast\n", + "manifest = Manifest(\n", + " client_name=\"chatgpt\",\n", + " client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sure! 
Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n" + ] + } + ], + "source": [ + "print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bootleg", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/manifest/clients/chatgpt.py b/manifest/clients/chatgpt.py new file mode 100644 index 0000000..794080d --- /dev/null +++ b/manifest/clients/chatgpt.py @@ -0,0 +1,130 @@ +"""Client class.""" +import logging +import os +from typing import Any, Callable, Dict, Optional, Tuple + +from pyChatGPT import ChatGPT + +from manifest.clients.client import Client +from manifest.request import LMRequest, Request + +logger = logging.getLogger(__name__) + + +class ChatGPTClient(Client): + """ChatGPT Client class.""" + + # No params for ChatGPT + PARAMS = {} + REQUEST_CLS = LMRequest + + def connect( + self, connection_str: Optional[str], client_args: Dict[str, Any] + ) -> None: + """ + Connect to ChatGPT. + + We use https://github.com/terry3041/pyChatGPT. 
+ + Args: + connection_str: connection string. + client_args: client arguments. + """ + self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str) + if self.session_key is None: + raise ValueError( + "ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment " + "variable or pass through `client_connection`. " + "For details, see https://github.com/terry3041/pyChatGPT " + "and go through instructions for getting a session key." + ) + self.host = None + for key in self.PARAMS: + setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) + self._chat_session = ChatGPT(self.session_key, verbose=False) + + def close(self) -> None: + """Close the client.""" + self._chat_session = None + + def clear_conversations(self) -> None: + """Clear conversations. + + Only works for ChatGPT. + """ + self._chat_session.clear_conversations() + + def get_generation_url(self) -> str: + """Get generation URL.""" + return "" + + def get_generation_header(self) -> Dict[str, str]: + """ + Get generation header. + + Returns: + header. + """ + return {} + + def supports_batch_inference(self) -> bool: + """Return whether the client supports batch inference.""" + return False + + def get_model_params(self) -> Dict: + """ + Get model params. + + By getting model params from the server, we can add to request + and make sure cache keys are unique to model. + + Returns: + model params. + """ + return {"model_name": "chatgpt", "engine": "chatgpt"} + + def format_response(self, response: Dict) -> Dict[str, Any]: + """ + Format response to dict. + + Args: + response: response + + Return: + response as dict + """ + return { + "model": "chatgpt", + "choices": [ + { + "text": response["message"], + } + ], + } + + def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]: + """ + Get request string function. + + Args: + request: request. + + Returns: + request function that takes no input. + request parameters as dict. 
+ """ + if isinstance(request.prompt, list): + raise ValueError("ChatGPT does not support batch inference.") + + prompt = str(request.prompt) + request_params = request.to_dict(self.PARAMS) + + def _run_completion() -> Dict: + try: + res = self._chat_session.send_message(prompt) + except Exception as e: + logger.error(f"ChatGPT error {e}.") + raise e + return self.format_response(res) + + return _run_completion, request_params diff --git a/manifest/clients/toma.py b/manifest/clients/toma.py index 3a20edf..56f849a 100644 --- a/manifest/clients/toma.py +++ b/manifest/clients/toma.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) # Engines are dynamically instantiated from API # but a few example engines are listed below. TOMA_ENGINES = { - "StableDiffusion", + # "StableDiffusion", "Together-gpt-JT-6B-v1", } diff --git a/manifest/manifest.py b/manifest/manifest.py index b81a3e7..a740d35 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -8,6 +8,7 @@ from manifest.caches.noop import NoopCache from manifest.caches.redis import RedisCache from manifest.caches.sqlite import SQLiteCache from manifest.clients.ai21 import AI21Client +from manifest.clients.chatgpt import ChatGPTClient from manifest.clients.cohere import CohereClient from manifest.clients.diffuser import DiffuserClient from manifest.clients.dummy import DummyClient @@ -22,6 +23,7 @@ logger = logging.getLogger(__name__) CLIENT_CONSTRUCTORS = { "openai": OpenAIClient, + "chatgpt": ChatGPTClient, "cohere": CohereClient, "ai21": AI21Client, "huggingface": HuggingFaceClient, diff --git a/pyproject.toml b/pyproject.toml index ed9b034..7b4b0ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ module = [ "accelerate.utils.modeling", "transformers", "flask", - "torch" + "torch", + "pyChatGPT", ] [tool.isort] diff --git a/setup.py b/setup.py index 6bc10a9..7933e2a 100644 --- a/setup.py +++ b/setup.py @@ -34,13 +34,18 @@ EXTRAS = { "api": [ "diffusers>=0.6.0", "Flask>=2.1.2", - 
"fastapi>=0.70.0", - "uvicorn>=0.18.0", "accelerate>=0.10.0", "transformers>=4.20.0", "torch>=1.8.0", "numpy>=1.20.0", ], + "app": [ + "fastapi>=0.70.0", + "uvicorn>=0.18.0", + ], + "chatgpt": [ + "pyChatGPT>=0.4.3", + ], "dev": [ "autopep8>=1.6.0", "black>=22.3.0",