fix chat completion

main
Noah Shinn 10 months ago
parent 6a295d2083
commit cf0e1c1a4b

@@ -24,9 +24,9 @@ def llm(prompt: str, model: Model, stop: List[str] = ["\n"]):
     cur_try = 0
     while cur_try < 6:
         if model == "text-davinci-003":
-            text = get_completion(prompt=prompt, temperature=cur_try * 0.2)
+            text = get_completion(prompt=prompt, temperature=cur_try * 0.2, stop=stop)
         else:
-            text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2)
+            text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2, stop=stop)
         # dumb way to do this
         if len(text.strip()) >= 5:
             return text
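The first hunk threads the caller's `stop` list through to both backends; before this change `llm` accepted it but silently dropped it, so generations could run past the intended delimiter. For context, a minimal sketch of the loop as it reads after the patch, assuming the repo's `Model` alias plus a `cur_try` increment and fallback return that fall outside the hunk:

```python
from typing import List

Model = str  # assumed stand-in for the repo's Model type alias

def llm(prompt: str, model: Model, stop: List[str] = ["\n"]) -> str:
    """Retry wrapper around get_completion/get_chat from this file."""
    cur_try = 0
    while cur_try < 6:
        # Escalate temperature on each retry (0.0, 0.2, ..., 1.0) so a
        # degenerate short completion is less likely to repeat verbatim.
        if model == "text-davinci-003":
            text = get_completion(prompt=prompt, temperature=cur_try * 0.2, stop=stop)
        else:
            text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2, stop=stop)
        # "dumb way to do this": treat near-empty output as a failure
        if len(text.strip()) >= 5:
            return text
        cur_try += 1  # assumed; the increment sits outside the hunk
    return ""  # assumed fallback after six attempts; not shown in the diff
```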

@@ -33,7 +33,7 @@ def get_completion(prompt: str, temperature: float = 0.0, max_tokens: int = 256,
     return response.choices[0].text
 
 @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_chat(prompt: str, model: Model, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
+def get_chat(prompt: str, model: Model, temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
     assert model != "text-davinci-003"
     messages = [
         {
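The second hunk gives `get_chat` an explicit `temperature` parameter, defaulting to 0.0 like `get_completion`'s. Without it, the `temperature=cur_try * 0.2` keyword that `llm` already passes would raise a TypeError before any request was made. The `@retry` decorator above the definition is tenacity's exponential backoff; a standalone sketch of that pattern, with `flaky_call` as a placeholder name:

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential

# wait_random_exponential sleeps a random interval that grows roughly as
# 2^attempt, clamped here to [1, 60] seconds; stop_after_attempt(6)
# re-raises the last exception once six tries have failed.
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def flaky_call() -> str:
    # Placeholder body: any exception raised here triggers another attempt.
    raise RuntimeError("transient failure")
```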
@@ -41,10 +41,11 @@ def get_chat(prompt: str, model: Model, max_tokens: int = 256, stop_strs: Option
             "content": prompt
         }
     ]
-    response = openai.Completion.create(
+    response = openai.ChatCompletion.create(
         model=model,
         messages=messages,
         max_tokens=max_tokens,
         stop=stop_strs,
+        temperature=temperature,
     )
-    return response.choices[0].message.content
+    return response.choices[0]["message"]["content"]
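The third hunk is the fix the commit title names: the old code sent chat requests through `openai.Completion.create`, i.e. to the completions endpoint, which chat models reject; they must go through `openai.ChatCompletion.create`. It also forwards the new `temperature` parameter and switches the return expression to dict-style indexing. A hedged usage example against the pre-1.0 `openai` SDK this code targets (the model name is illustrative):

```python
import openai  # assumes openai<1.0, the SDK generation this commit is written against

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # illustrative chat model
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=256,
    stop=["\n"],
    temperature=0.0,
)
# The old SDK returns OpenAIObject values that support both dict-style and
# attribute access, so .message.content and ["message"]["content"] behave
# the same; the commit standardizes on the dict form.
print(response.choices[0]["message"]["content"])
```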
