fix: ai21 prompt added with logprobs

laurel/helm
Laurel Orr 2 years ago
parent b34e2a714b
commit c04aa1e62c

@ -1,7 +1,7 @@
# Manifest
How to make prompt programming with FMs a little easier.
## Install
# Install
Download the code:
```bash
git clone git@github.com:HazyResearch/manifest.git
@ -20,7 +20,7 @@ pip install poetry
make dev
```
## Getting Started
# Getting Started
Getting started is simple. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>`, then run
```python
@ -51,14 +51,14 @@ manifest.run("Why is the grass green?")
We also support a Redis backend.
## Manifest Components
# Manifest Components
Manifest is meant to be a very lightweight package to help with prompt iteration. Three key design decisions (illustrated in the sketch after this list) are
* Prompts are functional -- they can take an input example and dynamically change
* All models are behind API calls (e.g., OpenAI)
* Everything can be cached for reuse to both save credits and to explore past results
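Concretely, these pieces compose as in the minimal sketch below (the `client_name` and `cache_name` keyword names are assumptions; only `cache_connection` appears verbatim in the tests later in this diff):
```python
from manifest import Manifest, Prompt

# Point Manifest at a model endpoint and a cache backend.
# Keyword names here are illustrative, not authoritative.
manifest = Manifest(
    client_name="openai",
    cache_name="sqlite",
    cache_connection="my_manifest.sqlite",
)

prompt = Prompt("Why is the grass green?")
result = manifest.run(prompt)  # repeat calls with the same key hit the cache
```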
### Prompts
## Prompts
A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
```python
@ -75,7 +75,7 @@ print(prompt())
>>> "Hello, my name is static"
```
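Per the first design decision, a prompt can also consume a runtime input. A minimal sketch, assuming `Prompt` accepts a single-argument callable (the static-string form above is the only one this diff shows directly):
```python
prompt = Prompt(lambda name: f"Hello, my name is {name}")
print(prompt("Laurel"))
>>> "Hello, my name is Laurel"
```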
### Sessions
## Sessions
Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
```bash
@ -137,7 +137,7 @@ If you want to change default parameters to a model, we pass those as `kwargs` t
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
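With this commit, those `kwargs` are handed to the client as a mutable dict: the client pops the keys it recognizes, and `Manifest.run` raises a `ValueError` on any leftovers (see the `manifest.py` hunk below). A sketch of the two outcomes, with a hypothetical `bad_input` argument:
```python
result = manifest.run(prompt, "Laurel", max_tokens=50)  # max_tokens is popped and used
result = manifest.run(prompt, "Laurel", bad_input=5)    # raises ValueError: not recognized
```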
### Huggingface Models
## Huggingface Models
To use a HuggingFace generative model, we provide a Flask application in `manifest/api` that hosts the models for you.
In a separate terminal or Tmux/Screen session, run
@ -155,7 +155,7 @@ manifest = Manifest(
If you have a custom model you trained, pass the model path to `--model_name`.
## Development
# Development
Before submitting a PR, run
```bash
export REDIS_PORT="6380" # or whatever PORT local redis is running for those tests

@ -66,7 +66,39 @@ class AI21Client(Client):
"""
return {"model_name": "ai21", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"object": "text_completion",
"model": self.engine,
"choices": [
{
"text": item["data"]["text"],
"logprobs": [
{
"token": tok["generatedToken"]["token"],
"logprob": tok["generatedToken"]["logprob"],
"start": tok["textRange"]["start"],
"end": tok["textRange"]["end"],
}
for tok in item["data"]["tokens"]
],
}
for item in response["completions"]
],
}
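# For reference, format_response above consumes an AI21 Studio payload shaped
# like the following (field names inferred from the accessors; values illustrative):
#   {"completions": [{"data": {"text": "...", "tokens": [
#       {"generatedToken": {"token": "...", "logprob": -0.1},
#        "textRange": {"start": 0, "end": 3}}]}}]}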
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -78,26 +110,22 @@ class AI21Client(Client):
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"engine": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"maxTokens": kwargs.get("maxTokens", self.max_tokens),
"topKReturn": kwargs.get("topKReturn", self.top_k_return),
"numResults": kwargs.get("numResults", self.num_results),
"topP": kwargs.get("topP", self.top_p),
"temperature": request_args.pop("temperature", self.temperature),
"maxTokens": request_args.pop("maxTokens", self.max_tokens),
"topKReturn": request_args.pop("topKReturn", self.top_k_return),
"numResults": request_args.pop("numResults", self.num_results),
"topP": request_args.pop("topP", self.top_p),
}
def _run_completion() -> Dict:
post_str = self.host + "/" + self.engine + "/complete"
res = requests.post(
post_str,
headers={"Authorization": f"Bearer {self.api_key}"},
json=request_params,
)
return res.json()
return self.format_response(res.json())
return _run_completion, request_params
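# Usage sketch (illustrative): the returned closure performs the POST when called.
#   run_fn, params = client.get_request("Why is the grass green?", {"maxTokens": 16})
#   formatted = run_fn()  # POSTs to <host>/<engine>/complete, then format_response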

@ -52,7 +52,9 @@ class Client(ABC):
raise NotImplementedError()
@abstractmethod
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request function.
@ -62,6 +64,7 @@ class Client(ABC):
Args:
query: query string.
request_args: request arguments; recognized keys are popped in place so leftovers can be detected by the caller.
Returns:
request function that takes no input.

@ -86,6 +86,7 @@ class CRFMClient(Client):
response as dict
"""
return {
"id": response.id,
"object": "text_completion",
"model": self.engine,
"choices": [
@ -104,7 +105,9 @@ class CRFMClient(Client):
],
}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -116,16 +119,22 @@ class CRFMClient(Client):
request parameters as dict.
"""
request_params = {
"model": kwargs.get("engine", self.engine),
"model": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
"num_completions": kwargs.get("num_completions", self.num_completions),
"stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
"top_p": kwargs.get("top_p", self.top_p),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"frequency_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_k_per_token": request_args.pop(
"top_k_per_token", self.top_k_per_token
),
"num_completions": request_args.pop(
"num_completions", self.num_completions
),
"stop_sequences": request_args.pop("stop_sequences", self.stop_sequences),
"top_p": request_args.pop("top_p", self.top_p),
"presence_penalty": request_args.pop(
"presence_penalty", self.presence_penalty
),
"frequency_penalty": request_args.pop(
"frequency_penalty", self.frequency_penalty
),
}

@ -42,7 +42,9 @@ class DummyClient(Client):
"""
return {"engine": "dummy"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -55,10 +57,10 @@ class DummyClient(Client):
"""
request_params = {
"prompt": query,
"num_results": kwargs.get("num_results", self.num_results),
"num_results": request_args.pop("num_results", self.num_results),
}
def _run_completion() -> Dict:
return {"choices": [{"text": "hello"}] * self.num_results}
return {"choices": [{"text": "hello"}] * request_params["num_results"]}
return _run_completion, request_params

@ -51,7 +51,9 @@ class HuggingFaceClient(Client):
res = requests.post(self.host + "/params")
return res.json()
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -64,15 +66,15 @@ class HuggingFaceClient(Client):
"""
request_params = {
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"top_k": kwargs.get("top_k", self.top_k),
"do_sample": kwargs.get("do_sample", self.do_sample),
"repetition_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"top_k": request_args.pop("top_k", self.top_k),
"do_sample": request_args.pop("do_sample", self.do_sample),
"repetition_penalty": request_args.pop(
"repetition_penalty", self.repetition_penalty
),
"n": kwargs.get("n", self.n),
"n": request_args.pop("n", self.n),
}
request_params.update(self.model_params)

@ -71,7 +71,9 @@ class OpenAIClient(Client):
"""
return {"model_name": "openai", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -83,18 +85,20 @@ class OpenAIClient(Client):
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"engine": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"frequency_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"frequency_penalty": request_args.pop(
"frequency_penalty", self.frequency_penalty
),
"logprobs": kwargs.get("logprobs", self.logprobs),
"best_of": kwargs.get("best_of", self.best_of),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"n": kwargs.get("n", self.n),
"logprobs": request_args.pop("logprobs", self.logprobs),
"best_of": request_args.pop("best_of", self.best_of),
"presence_penalty": request_args.pop(
"presence_penalty", self.presence_penalty
),
"n": request_args.pop("n", self.n),
}
def _run_completion() -> Dict:

@ -46,7 +46,9 @@ class OPTClient(Client):
"""
return {"model_name": "opt"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -60,10 +62,10 @@ class OPTClient(Client):
request_params = {
"prompt": query,
"engine": "opt",
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"n": kwargs.get("n", self.n),
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"n": request_args.pop("n", self.n),
}
def _run_completion() -> Dict:

@ -121,7 +121,10 @@ class Manifest:
prompt = Prompt(prompt)
stop_token = stop_token if stop_token is not None else self.stop_token
prompt_str = prompt(input)
possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
# Must pass kwargs as a dict so the client's "pop" methods remove used arguments
possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent

@ -25,6 +25,12 @@ class Response:
"Response must be serialized to a dict with a "
"list of choices with text field"
)
if "logprobs" in self._response["choices"][0]:
if not isinstance(self._response["choices"][0]["logprobs"], list):
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with logprobs field"
)
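# A serialized response passing both checks looks like, e.g.:
#   {"choices": [{"text": "...", "logprobs": [{"token": "...", "logprob": -0.1}]}]}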
self._cached = cached
self._request_params = request_params

@ -26,6 +26,7 @@ transformers = "^4.19.2"
torch = "^1.8"
requests = "^2.27.1"
tqdm = "^4.64.0"
uuid = "^1.30"
[tool.poetry.dev-dependencies]
black = "^22.3.0"

@ -3,12 +3,14 @@ Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
"""
from manifest.clients.dummy import DummyClient
def test_init():
"""Test client initialization."""
client = DummyClient(connection_str=None)
assert client.num_results == 1
args = {"num_results": 3}
client = DummyClient(connection_str=None, client_args=args)
assert client.num_results == 3
@ -21,3 +23,7 @@ def test_get_request():
request_func, request_params = client.get_request("hello")
assert request_params == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}
request_func, request_params = client.get_request("hello", {"num_results": 5})
assert request_params == {"prompt": "hello", "num_results": 5}
assert request_func() == {"choices": [{"text": "hello"}] * 5}

@ -55,6 +55,12 @@ def test_run(sqlite_cache, num_results, return_response):
cache_connection=sqlite_cache,
num_results=num_results,
)
prompt = Prompt("This is a prompt")
with pytest.raises(ValueError) as exc_info:
result = manifest.run(prompt, return_response=return_response, bad_input=5)
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
prompt = Prompt("This is a prompt")
result = manifest.run(prompt, return_response=return_response)
if return_response:
