diff --git a/alfworld_runs/alfworld_trial.py b/alfworld_runs/alfworld_trial.py index d2d4859..b7b3223 100644 --- a/alfworld_runs/alfworld_trial.py +++ b/alfworld_runs/alfworld_trial.py @@ -24,9 +24,9 @@ def llm(prompt: str, model: Model, stop: List[str] = ["\n"]): cur_try = 0 while cur_try < 6: if model == "text-davinci-003": - text = get_completion(prompt=prompt, temperature=cur_try * 0.2, stop=stop) + text = get_completion(prompt=prompt, temperature=cur_try * 0.2, stop_strs=stop) else: - text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2, stop=stop) + text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2, stop_strs=stop) # dumb way to do this if len(text.strip()) >= 5: return text