HotPotQA runs
parent
9e4cd6402a
commit
2b3357a534
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,245 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"from react_cls import ReactAgent\n",
|
||||
"from mocks import DocStoreExplorerMock, LLMMock"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def summarize_trial(agents):\n",
|
||||
" correct = [a for a in agents if a.is_correct()]\n",
|
||||
" halted = [a for a in agents if a.is_halted()]\n",
|
||||
" incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
|
||||
" return correct, incorrect, halted\n",
|
||||
"\n",
|
||||
"def log_trial(agents, trial_n):\n",
|
||||
" correct, incorrect, halted = summarize_trial(agents)\n",
|
||||
"\n",
|
||||
" log = f\"\"\"\n",
|
||||
"########################################\n",
|
||||
"BEGIN TRIAL {trial_n}\n",
|
||||
"Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}\n",
|
||||
"#######################################\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
|
||||
" for agent in correct:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
|
||||
" for agent in incorrect:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN HALTED AGENTS --------------\\n\\n'\n",
|
||||
" for agent in halted:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" return log"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agents = [ReactAgent(row['question'], row['answer']) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trial = 0\n",
|
||||
"log = ''"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"q = 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Trial: 4 (0/66)\n",
|
||||
"Trial: 4 (1/66)\n",
|
||||
"Trial: 4 (2/66)\n",
|
||||
"Trial: 4 (3/66)\n",
|
||||
"Trial: 4 (4/66)\n",
|
||||
"Trial: 4 (5/66)\n",
|
||||
"Trial: 4 (6/66)\n",
|
||||
"Trial: 4 (7/66)\n",
|
||||
"Trial: 4 (8/66)\n",
|
||||
"Trial: 4 (9/66)\n",
|
||||
"Trial: 4 (10/66)\n",
|
||||
"Trial: 4 (11/66)\n",
|
||||
"Trial: 4 (12/66)\n",
|
||||
"Trial: 4 (13/66)\n",
|
||||
"Trial: 4 (14/66)\n",
|
||||
"Trial: 4 (15/66)\n",
|
||||
"Trial: 4 (16/66)\n",
|
||||
"Trial: 4 (17/66)\n",
|
||||
"Trial: 4 (18/66)\n",
|
||||
"Trial: 4 (19/66)\n",
|
||||
"Trial: 4 (20/66)\n",
|
||||
"Trial: 4 (21/66)\n",
|
||||
"Trial: 4 (22/66)\n",
|
||||
"Trial: 4 (23/66)\n",
|
||||
"Trial: 4 (24/66)\n",
|
||||
"Trial: 4 (25/66)\n",
|
||||
"Trial: 4 (26/66)\n",
|
||||
"Trial: 4 (27/66)\n",
|
||||
"Trial: 4 (28/66)\n",
|
||||
"Trial: 4 (29/66)\n",
|
||||
"Trial: 4 (30/66)\n",
|
||||
"Trial: 4 (31/66)\n",
|
||||
"Trial: 4 (32/66)\n",
|
||||
"Trial: 4 (33/66)\n",
|
||||
"Trial: 4 (34/66)\n",
|
||||
"Trial: 4 (35/66)\n",
|
||||
"Trial: 4 (36/66)\n",
|
||||
"Trial: 4 (37/66)\n",
|
||||
"Trial: 4 (38/66)\n",
|
||||
"Trial: 4 (39/66)\n",
|
||||
"Trial: 4 (40/66)\n",
|
||||
"Trial: 4 (41/66)\n",
|
||||
"Trial: 4 (42/66)\n",
|
||||
"Trial: 4 (43/66)\n",
|
||||
"Trial: 4 (44/66)\n",
|
||||
"Trial: 4 (45/66)\n",
|
||||
"Trial: 4 (46/66)\n",
|
||||
"Trial: 4 (47/66)\n",
|
||||
"Trial: 4 (48/66)\n",
|
||||
"Trial: 4 (49/66)\n",
|
||||
"Trial: 4 (50/66)\n",
|
||||
"Trial: 4 (51/66)\n",
|
||||
"Trial: 4 (52/66)\n",
|
||||
"Trial: 4 (53/66)\n",
|
||||
"Trial: 4 (54/66)\n",
|
||||
"Trial: 4 (55/66)\n",
|
||||
"Trial: 4 (56/66)\n",
|
||||
"Trial: 4 (57/66)\n",
|
||||
"Trial: 4 (58/66)\n",
|
||||
"Trial: 4 (59/66)\n",
|
||||
"Trial: 4 (60/66)\n",
|
||||
"Trial: 4 (61/66)\n",
|
||||
"Trial: 4 (62/66)\n",
|
||||
"Trial: 4 (63/66)\n",
|
||||
"Trial: 4 (64/66)\n",
|
||||
"Trial: 4 (65/66)\n",
|
||||
"Finished Trial 5, Correct: 34, Incorrect: 56, Halted: 12\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agents_to_run = [a for a in agents if not a.is_correct()]\n",
|
||||
"\n",
|
||||
"while q < len(agents_to_run):\n",
|
||||
" print(f'Trial: {trial} ({q}/{len(agents_to_run)})')\n",
|
||||
" agents_to_run[q].run()\n",
|
||||
" q += 1\n",
|
||||
"\n",
|
||||
"trial += 1\n",
|
||||
"\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect, halted = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open('output/base_react/100_questions_5_trials.txt', 'w') as f:\n",
|
||||
" f.write(log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/base_react_dicts.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in agents]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/base_react_dicts.joblib')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.9"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,215 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"from react_cls import ReactReflectAgent, format_reflections\n",
|
||||
"from mocks import DocStoreExplorerMock, LLMMock"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def summarize_trial(agents):\n",
|
||||
" correct = [a for a in agents if a.is_correct()]\n",
|
||||
" incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
|
||||
" return correct, incorrect\n",
|
||||
"\n",
|
||||
"def remove_fewshot(prompt: str) -> str:\n",
|
||||
" prefix = prompt.split('Here are some examples:')[0]\n",
|
||||
" suffix = prompt.split('(END OF EXAMPLES)')[1]\n",
|
||||
" return prefix.strip('\\n').strip() +'\\n' + suffix.strip('\\n').strip()\n",
|
||||
"\n",
|
||||
"def log_trial(agents, trial_n):\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
"\n",
|
||||
" log = f\"\"\"\n",
|
||||
"########################################\n",
|
||||
"BEGIN TRIAL {trial_n}\n",
|
||||
"Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}\n",
|
||||
"#######################################\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
|
||||
" for agent in correct:\n",
|
||||
" log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
|
||||
" for agent in incorrect:\n",
|
||||
" log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" return log\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agents = [ReactReflectAgent(row['question'], row['answer']) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trial = 0\n",
|
||||
"log = ''\n",
|
||||
"last_correct = 0 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
"trial += 1\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/last_trial_react/react_incorrect_dicts_trial_0.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in incorrect]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/last_trial_react/react_incorrect_dicts_trial_0.joblib')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"while last_correct != correct:\n",
|
||||
" last_correct, _ = summarize_trial(agents)\n",
|
||||
" for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
" trial += 1\n",
|
||||
" log += log_trial(agents, trial)\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt + reflexion')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
"trial += 1\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open('output/last_trial_react/100_questions_5_trials.txt', 'w') as f:\n",
|
||||
" f.write(log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/reflect/react_reflect_50_correct_dicts.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in correct]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/reflect/react_reflect_50_correct_dicts.joblib')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.9"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,172 @@
|
||||
import os
|
||||
from typing import List
|
||||
import dotenv
|
||||
|
||||
import gym
|
||||
import tiktoken
|
||||
from langchain import OpenAI
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from environment import QAEnv
|
||||
from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER
|
||||
from fewshots import WEBTHINK_SIMPLE6, REFLECTIONS
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
class ReactAgent:
    """
    A question answering ReAct Agent.

    Interleaves Thought / Action / Observation steps against a QAEnv until
    the environment terminates the episode or the prompt grows too long.
    """
    def __init__(self,
                 question: str,
                 env: QAEnv,
                 agent_prompt: PromptTemplate = react_agent_prompt,
                 react_llm: BaseLLM = None,
                 ) -> None:
        """
        Args:
            question: the question this agent must answer.
            env: environment that executes Search/Lookup/Finish actions.
            agent_prompt: template filled with examples, question and scratchpad.
            react_llm: LLM used for Thought/Action completions. When None a
                fresh OpenAI text-davinci-003 client is built for this agent.
        """
        self.question = question
        self.agent_prompt = agent_prompt
        self.react_examples = WEBTHINK_SIMPLE6

        self.env = env
        self.env.reset()
        self.reset()
        self.truncated, self.reward, self.terminated = False, False, False

        # Fix: the original default `react_llm=OpenAI(...)` was evaluated at
        # import time, so a single shared client was reused by every agent and
        # the module could not even be imported without OPENAI_API_KEY set
        # (e.g. when injecting LLMMock). Build the default lazily instead.
        if react_llm is None:
            react_llm = OpenAI(
                temperature=0,
                max_tokens=100,
                model_name="text-davinci-003",
                model_kwargs={"stop": "\n"},
                openai_api_key=os.environ['OPENAI_API_KEY'])
        self.llm = react_llm

        self.enc = tiktoken.encoding_for_model("text-davinci-003")

    def run(self, reset = True) -> None:
        """Roll the episode forward until termination or truncation."""
        if reset:
            self.env.reset()
            self.reset()

        while not (self.is_truncated() or self.is_terminated()):
            self.step()

    def step(self) -> None:
        """Execute one Thought -> Action -> Observation cycle."""
        # Think
        self.scratchpad += f'\nThought {self.curr_step}:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += f'\nAction {self.curr_step}:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        print(self.scratchpad.split('\n')[-1])

        # Observe: the env also advances the step counter for us.
        self.scratchpad += f'\nObservation {self.curr_step}: '
        observation, self.reward, self.terminated, self.truncated, self.curr_step = self.env.step(action)
        self.scratchpad += observation
        print(self.scratchpad.split('\n')[-1])

    def prompt_agent(self) -> str:
        """Query the LLM with the current prompt and clean up the completion."""
        return format_step(self.llm(self._build_agent_prompt()))

    def _build_agent_prompt(self) -> str:
        """Fill the agent prompt template with examples, question, scratchpad."""
        return self.agent_prompt.format(
                            examples = self.react_examples,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def is_terminated(self) -> bool:
        return self.env.is_terminated()

    def is_correct(self) -> bool:
        return self.env.is_correct()

    def is_truncated(self) -> bool:
        # Truncate when the env says so, or when the assembled prompt exceeds
        # 3896 tokens — presumably to leave headroom for the 100-token
        # completion within the model context window; confirm the budget.
        return self.env.is_truncated() or (len(self.enc.encode(self._build_agent_prompt())) > 3896)

    def reset(self) -> None:
        """Clear the scratchpad and restart step numbering at 1."""
        self.scratchpad = ''
        self.curr_step = 1
|
||||
|
||||
|
||||
class ReactReflectAgent(ReactAgent):
    """
    A question answering Self-Reflecting React Agent.

    When re-run after a finished-but-incorrect episode, it first asks a second
    LLM for a self-reflection, then replays the base ReAct loop with all
    accumulated reflections injected into the agent prompt.
    """
    def __init__(self,
                 question: str,
                 env: QAEnv,
                 agent_prompt: PromptTemplate = react_reflect_agent_prompt,
                 reflect_prompt: PromptTemplate = reflect_prompt,
                 react_llm: BaseLLM = None,
                 reflect_llm: BaseLLM = None,
                 ) -> None:
        """
        Args:
            question: the question this agent must answer.
            env: environment that executes Search/Lookup/Finish actions.
            agent_prompt: template that also takes a {reflections} slot.
            reflect_prompt: template used to elicit a self-reflection.
            react_llm: LLM for Thought/Action steps (default built lazily).
            reflect_llm: LLM for reflections (default built lazily).
        """
        # Fix: the original `OpenAI(...)` defaults were evaluated at import
        # time, sharing one client across all agents and requiring
        # OPENAI_API_KEY just to import the module. Build them lazily.
        if react_llm is None:
            react_llm = OpenAI(
                temperature=0,
                max_tokens=100,
                model_name="text-davinci-003",
                model_kwargs={"stop": "\n"},
                openai_api_key=os.environ['OPENAI_API_KEY'])
        if reflect_llm is None:
            reflect_llm = OpenAI(
                temperature=0,
                max_tokens=250,
                model_name="text-davinci-003",
                openai_api_key=os.environ['OPENAI_API_KEY'])

        super().__init__(question, env, agent_prompt, react_llm)
        self.reflect_llm = reflect_llm
        self.reflect_prompt = reflect_prompt
        self.reflect_examples = REFLECTIONS
        self.reflections = []

    def run(self, reset = True) -> None:
        """Reflect on a finished-but-wrong previous episode, then run ReAct."""
        if (self.is_terminated() or self.is_truncated()) and not self.is_correct():
            self.reflect()

        ReactAgent.run(self, reset)

    def reflect(self) -> None:
        """Append a new self-reflection about the failed attempt."""
        self.reflections.append(self.prompt_reflection())

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and clean up its completion."""
        return format_step(self.reflect_llm(self._build_reflection_prompt()))

    def _build_reflection_prompt(self) -> str:
        """Fill the reflection template with a token-budgeted scratchpad."""
        return self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            question = self.question,
                            scratchpad = self._format_scratchpad())

    def _build_agent_prompt(self) -> str:
        """Agent prompt including formatted reflections from prior trials."""
        return self.agent_prompt.format(
                            examples = self.react_examples,
                            reflections = format_reflections(self.reflections),
                            question = self.question,
                            scratchpad = self.scratchpad)

    def _format_scratchpad(self) -> str:
        """Shrink the scratchpad under ~1600 tokens by eliding longest lines."""
        lines = self.scratchpad.split('\n')
        lines_by_tokens = sorted(lines, key=lambda x: len(self.enc.encode(x)))
        while len(self.enc.encode('\n'.join(lines))) > 1600:
            # Replace the currently-longest line with '<prefix>: ...'
            ind = lines.index(lines_by_tokens.pop(-1))
            line = lines[ind]
            lines[ind] = line.split(':')[0] + ': ...'
        return '\n'.join(lines)
|
||||
|
||||
|
||||
|
||||
### String Operations ###
|
||||
def format_reflections(reflections: List[str]) -> str:
    """Render past reflections as a bulleted section, or '' when none exist."""
    if not reflections:
        return ''
    bullets = '\n- '.join(r.strip() for r in reflections)
    return REFLECTION_HEADER + 'Reflections:\n- ' + bullets
|
||||
|
||||
def format_step(step: str) -> str:
    """Collapse an LLM completion to one trimmed line with no newlines."""
    trimmed = step.strip('\n').strip()
    return trimmed.replace('\n', '')
|
||||
|
||||
|
||||
|
@ -0,0 +1,84 @@
|
||||
import joblib
|
||||
from react_cls import CoTAgent
|
||||
from mocks import DocStoreExplorerMock, LLMMock
|
||||
import numpy as np
|
||||
|
||||
def summarize_trial(agents):
    """Partition agents into (correct, incorrect); unfinished ones are dropped."""
    correct = [agent for agent in agents if agent.is_correct()]
    incorrect = [agent for agent in agents
                 if agent.is_finished() and not agent.is_correct()]
    return correct, incorrect
|
||||
|
||||
def log_trial(agents, trial_n):
    """Build a human-readable report for one trial, grouped by outcome."""
    correct, incorrect = summarize_trial(agents)

    # Header banner with the trial number and the outcome tallies.
    parts = [f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}
#######################################
"""]

    parts.append('------------- BEGIN CORRECT AGENTS -------------\n\n')
    for agent in correct:
        parts.append(f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n')

    parts.append('------------- BEGIN INCORRECT AGENTS -----------\n\n')
    for agent in incorrect:
        parts.append(f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n')
    return ''.join(parts)
|
||||
|
||||
if __name__ == '__main__':
    # Load the HotPotQA distractor sample and attach, for every row, the
    # paragraphs that contain the supporting facts.
    hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)
    hotpot['supporting_paragraphs'] = None
    for ind, row in hotpot.iterrows():
        supporting_articles = row['supporting_facts']['title']
        articles = row['context']['title']
        sentences = row['context']['sentences']
        supporting_paragraphs = []
        for article in supporting_articles:
            # np.where-based indexing assumes `articles`/`sentences` are numpy
            # arrays as stored in the sample -- TODO confirm against the data.
            supporting_paragraph = ''.join(sentences[np.where(articles == article)][0])
            supporting_paragraphs.append(supporting_paragraph)
        hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs

    # Collapse each row's paragraph list into one blank-line-separated string.
    for ind, row in hotpot.iterrows():
        supporting_paragraphs = row['supporting_paragraphs']
        supporting_paragraphs = '\n\n'.join(supporting_paragraphs)
        hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs

    agents = [CoTAgent(row['question'], row['supporting_paragraphs'], row['answer']) for _, row in hotpot.iterrows()]
    trial = 0
    log = ''
    # First trial: run every not-yet-correct agent once, without reflection.
    for agent in [a for a in agents if not a.is_correct()]:
        agent.run(reflect = False)
        print(f'Answer: {agent.key}')
    trial += 1

    log += log_trial(agents, trial)
    correct, incorrect = summarize_trial(agents)
    print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')
    # Persist a stringified snapshot of every agent's state.
    dicts = [dict(a.__dict__) for a in agents]
    for d in dicts:
        for k, v in d.items():
            d[k] = str(v)

    joblib.dump(dicts, 'output/base_cot/cot_reflect_50_correct_dicts-8-trials.joblib')
    print(log)

    with open('output/base_cot/100_questions_8_trials.txt', 'w') as f:
        f.write(log)

    # Second pass: re-run every still-incorrect agent once more.
    trial = 0
    log = ''
    q = 0
    agents_to_run = [a for a in agents if not a.is_correct()]

    while q < len(agents_to_run):
        print(f'Trial: {trial} ({q}/{len(agents_to_run)})')
        agents_to_run[q].run()
        q += 1

    trial += 1

    log += log_trial(agents, trial)
    # BUG FIX: summarize_trial in this file returns only (correct, incorrect);
    # the original three-way unpack (`correct, incorrect, halted = ...`)
    # raised ValueError here and the final print referenced `halted`.
    correct, incorrect = summarize_trial(agents)
    print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')
|
Binary file not shown.
@ -0,0 +1,101 @@
|
||||
import re
|
||||
import string
|
||||
from typing import Tuple
|
||||
|
||||
import gym
|
||||
from langchain import Wikipedia
|
||||
from langchain.agents.react.base import DocstoreExplorer
|
||||
|
||||
class QAEnv(gym.Env):
    """Gym-style question answering environment over a Wikipedia docstore.

    Actions are strings of the form 'Search[entity]', 'Lookup[keyword]' or
    'Finish[answer]'; an episode terminates on Finish and is truncated after
    max_steps actions.
    """
    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 explorer: DocstoreExplorer = None):
        """
        Args:
            question: the question posed to the agent.
            key: gold answer, compared by normalized exact match.
            max_steps: step budget before the episode is truncated.
            explorer: docstore wrapper; when None a fresh Wikipedia-backed
                explorer is built per env.
        """
        self.question = question
        self.key = key
        self.max_steps = max_steps
        # Fix: the original default `DocstoreExplorer(Wikipedia())` was built
        # once at import time and shared by every env instance; build lazily.
        self.explorer = explorer if explorer is not None else DocstoreExplorer(Wikipedia())

        self.reset()

    def reset(self):
        """Restart the episode: clear the step counter, flags and answer."""
        self.curr_step = 0
        self.terminated = False
        self.answer = ''

    def step(self, action: str) -> Tuple[str, bool, bool, bool, int]:
        """Apply one action string.

        Returns:
            (observation, reward, terminated, truncated, curr_step); reward is
            the boolean exact-match result so far.
            (Fix: annotation previously claimed the 5th element was a bool —
            it is the int step counter.)
        """
        action_type, argument = parse_action(action)

        if action_type == 'Finish':
            self.answer = argument
            if self.is_correct():
                observation = 'Answer is CORRECT'
            else:
                observation = 'Answer is INCORRECT'
            self.terminated = True

        elif action_type == 'Search':
            try:
                observation = self.explorer.search(argument).strip('\n').strip()
            except Exception as e:
                # Best-effort: surface the failure as an observation so the
                # agent can retry rather than crashing the episode.
                print(e)
                observation = 'Could not find that page, please try again.'

        elif action_type == 'Lookup':
            try:
                observation = self.explorer.lookup(argument).strip('\n').strip()
            except ValueError:
                observation = 'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'

        else:
            observation = 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'

        reward = self.is_correct()
        terminated = self.is_terminated()
        truncated = self.is_truncated()

        self.curr_step += 1

        return observation, reward, terminated, truncated, self.curr_step

    def is_correct(self) -> bool:
        """Normalized exact match of the current answer against the key."""
        return EM(self.answer, self.key)

    def is_terminated(self) -> bool:
        return self.terminated

    def is_truncated(self) -> bool:
        return self.curr_step >= self.max_steps
|
||||
|
||||
def parse_action(string):
    """Split 'Type[argument]' into ('Type', 'argument'); (None, None) if malformed."""
    match = re.match(r'^(\w+)\[(.+)\]$', string)
    if not match:
        return None, None
    return match.group(1), match.group(2)
|
||||
|
||||
def normalize_answer(s):
    """Lowercase, drop punctuation and articles, and collapse whitespace.

    Same pipeline order as the SQuAD-style normalization:
    lower -> remove punctuation -> remove a/an/the -> squeeze whitespace.
    """
    lowered = s.lower()
    punctuation = set(string.punctuation)
    no_punc = "".join(ch for ch in lowered if ch not in punctuation)
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punc)
    return " ".join(no_articles.split())
|
||||
|
||||
def EM(answer, key) -> bool:
    """True iff answer and key are identical after normalization."""
    normalized_answer = normalize_answer(answer)
    normalized_key = normalize_answer(key)
    return normalized_answer == normalized_key
|
@ -0,0 +1,43 @@
|
||||
from langchain.agents.react.base import DocstoreExplorer
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
def reactLLMMock(prompt: str) -> str:
    """Canned ReAct completion keyed off the prompt's final line.

    A prompt ending in 'Thought N:' gets a fixed thought; one ending in
    'Action N:' gets a fixed Lookup action; anything else is rejected.
    """
    step_kind = prompt.split('\n')[-1].strip().split(' ')[0].lower()
    if step_kind == 'thought':
        return 'It does not mention the eastern sector. So I need to look up eastern sector.'
    if step_kind == 'action':
        return 'Lookup[eastern sector]'
    raise Exception('Invalid action type')
|
||||
|
||||
|
||||
def reflectLLMMock(prompt: str) -> str:
    """Canned reflection completion; ignores the prompt contents."""
    return "Last time i should have answered correctly"
|
||||
|
||||
class LLMMock(BaseLLM):
    """LLM stand-in that routes prompts to the react or reflect mock.

    Dispatch is on the first word of the prompt's first line: 'Solve' goes
    to the ReAct mock, 'You' to the reflection mock.
    """
    def __init__(self):
        ...

    def __call__(self, prompt: str) -> str:
        first_word = prompt.split('\n')[0].split(' ')[0]
        if first_word == 'Solve':
            return reactLLMMock(prompt)
        if first_word == 'You':
            return reflectLLMMock(prompt)
        raise Exception("Invalid LLM prompt")

    def get_num_tokens(self, text: str) -> int:
        # Token accounting is irrelevant for the mock.
        return 0
|
||||
|
||||
class DocStoreExplorerMock(DocstoreExplorer):
    """Docstore explorer that always returns fixed Colorado-orogeny passages."""

    def __init__(self):
        self.summary = "The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas."
        self.body = "(Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny."

    def search(self, search: str, sents: int = 5) -> str:
        # Every search yields the same canned summary, regardless of query.
        return self.summary

    def lookup(self, term: str) -> str:
        # Every lookup yields the same canned result sentence.
        return self.body
|
@ -0,0 +1,142 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
# Chain-of-thought QA prompt with supporting context: the agent emits one
# Thought, then answers with Finish[answer].
COT_INSTRUCTION = """Solve a question answering task by having a Thought, then Finish with your answer. Thought can reason about the current situation. Finish[answer] returns the answer and finishes the task. You will be given context that you should use to help you answer the question.
Here are some examples:
{examples}
(END OF EXAMPLES)
{reflections}
Relevant Context: {context}
Question: {question}{scratchpad}"""

# Same CoT prompt, but with blank lines setting the {reflections} block apart
# (used by the self-reflecting agent).
COT_AGENT_REFLECT_INSTRUCTION = """Solve a question answering task by having a Thought, then Finish with your answer. Thought can reason about the current situation. Finish[answer] returns the answer and finishes the task. You will be given context that you should use to help you answer the question.
Here are some examples:
{examples}
(END OF EXAMPLES)

{reflections}

Relevant Context: {context}
Question: {question}{scratchpad}"""

# Prompt asking the model to diagnose a failed trial and plan a fix.
# NOTE(review): "self refection" is a typo for "self reflection" inside the
# runtime prompt text; left byte-identical here to preserve behavior.
COT_REFLECT_INSTRUCTION = """You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given access to relevant context and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[<answer>] or there is a phrasing discrepancy with your provided answer and the answer key. In a few sentences, Diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences.
Here are some examples:
{examples}
(END OF EXAMPLES)

Previous trial:
Relevant Context: {context}
Question: {question}{scratchpad}

Reflection:"""

# Templates wired to the instruction strings above.
cot_agent_prompt = PromptTemplate(
    input_variables=["examples", "reflections", "context", "question", "scratchpad"],
    template = COT_INSTRUCTION,
    )

cot_reflect_agent_prompt = PromptTemplate(
    input_variables=["examples", "reflections", "context", "question", "scratchpad"],
    template = COT_AGENT_REFLECT_INSTRUCTION,
    )

# The reflection prompt itself takes no {reflections} variable.
cot_reflect_prompt = PromptTemplate(
    input_variables=["examples", "context", "question", "scratchpad"],
    template = COT_REFLECT_INSTRUCTION,
    )
|
||||
|
||||
# "Simple" CoT prompt variant without a dedicated supporting-context label.
COT_SIMPLE_INSTRUCTION = """Solve a question answering task by having a Thought, then Finish with your answer. Thought can reason about the current situation. Finish[answer] returns the answer and finishes the task.
Here are some examples:
{examples}
(END OF EXAMPLES)
{reflections}
{context}
Question: {question}{scratchpad}"""

# Simple CoT prompt for the reflecting agent; note that {context} precedes
# {reflections} in this variant, unlike COT_SIMPLE_INSTRUCTION above.
COT_SIMPLE_AGENT_REFLECT_INSTRUCTION = """Solve a question answering task by having a Thought, then Finish with your answer. Thought can reason about the current situation. Finish[answer] returns the answer and finishes the task.
Here are some examples:
{examples}
(END OF EXAMPLES)
{context}
{reflections}

Question: {question}{scratchpad}"""

# Failure-diagnosis prompt for the simple (question-only) CoT agent.
# NOTE(review): "self refection" typo is in the runtime prompt text; kept.
COT_SIMPLE_REFLECT_INSTRUCTION = """You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[<answer>] or there is a phrasing discrepancy with your provided answer and the answer key. In a few sentences, Diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences.
Here are some examples:
{examples}
(END OF EXAMPLES)
{context}
Previous trial:
Question: {question}{scratchpad}

Reflection:"""

# Templates wired to the simple instruction strings above.
cot_simple_agent_prompt = PromptTemplate(
    input_variables=["examples", "question", "reflections", "context", "scratchpad"],
    template = COT_SIMPLE_INSTRUCTION,
    )

cot_simple_reflect_agent_prompt = PromptTemplate(
    input_variables=["examples", "context", "reflections", "question", "scratchpad"],
    template = COT_SIMPLE_AGENT_REFLECT_INSTRUCTION,
    )

cot_simple_reflect_prompt = PromptTemplate(
    input_variables=["examples", "question", "context", "scratchpad"],
    template = COT_SIMPLE_REFLECT_INSTRUCTION,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# ReAct prompt templates (Thought / Action / Observation loop over a
# Wikipedia docstore) and the reflection headers shared by the agents.
# ---------------------------------------------------------------------------

# Base ReAct prompt: no reflections.
REACT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
(3) Finish[answer], which returns the answer and finishes the task.
You may take as many steps as necessary.
Here are some examples:
{examples}
(END OF EXAMPLES)
Question: {question}{scratchpad}"""

# ReAct prompt with a {reflections} slot inserted before the question.
REACT_REFLECT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
(3) Finish[answer], which returns the answer and finishes the task.
You may take as many steps as necessary.
Here are some examples:
{examples}
(END OF EXAMPLES)

{reflections}

Question: {question}{scratchpad}"""

# Headers prepended by format_reflections / format_last_attempt depending on
# the chosen reflection strategy.
REFLECTION_HEADER = 'You have attempted to answer following question before and failed. The following reflection(s) give a plan to avoid failing to answer the question in the same way you did previously. Use them to improve your strategy of correctly answering the given question.\n'
REFLECTION_AFTER_LAST_TRIAL_HEADER = 'The following reflection(s) give a plan to avoid failing to answer the question in the same way you did previously. Use them to improve your strategy of correctly answering the given question.\n'
LAST_TRIAL_HEADER = 'You have attempted to answer the following question before and failed. Below is the last trial you attempted to answer the question.\n'

# Self-reflection prompt for the ReAct agents (no {context}: the trial
# scratchpad already contains the retrieved passages).
# NOTE(review): "self refection" is a typo inside the runtime prompt string;
# left as-is to avoid changing model-visible text.
REFLECT_INSTRUCTION = """You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given access to an Docstore API environment and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[<answer>], or you used up your set number of reasoning steps. In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences.
Here are some examples:
{examples}

Previous trial:
Question: {question}{scratchpad}

Reflection:"""

# Prompt for the plain ReAct agent.
react_agent_prompt = PromptTemplate(
                        input_variables=["examples", "question", "scratchpad"],
                        template = REACT_INSTRUCTION,
                        )

# Prompt for the reflecting ReAct agent (adds "reflections").
react_reflect_agent_prompt = PromptTemplate(
                        input_variables=["examples", "reflections", "question", "scratchpad"],
                        template = REACT_REFLECT_INSTRUCTION,
                        )

# Prompt used to generate a reflection from a failed ReAct trial.
reflect_prompt = PromptTemplate(
                        input_variables=["examples", "question", "scratchpad"],
                        template = REFLECT_INSTRUCTION,
                        )
|
||||
|
||||
|
||||
|
@ -0,0 +1,379 @@
|
||||
import re, string, os
|
||||
from typing import List, Union, Literal
|
||||
|
||||
import tiktoken
|
||||
from langchain import OpenAI, Wikipedia
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.agents.react.base import DocstoreExplorer
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.prompts import PromptTemplate
|
||||
from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER, LAST_TRIAL_HEADER, REFLECTION_AFTER_LAST_TRIAL_HEADER
|
||||
from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt, COT_INSTRUCTION, COT_REFLECT_INSTRUCTION
|
||||
from fewshots import WEBTHINK_SIMPLE6, REFLECTIONS, COT, COT_REFLECT
|
||||
|
||||
class CoTAgent:
    """Chain-of-Thought QA agent with optional self-reflection (Reflexion).

    One instance handles one (question, context, key) triple.  Each call to
    :meth:`run` performs a single Thought -> Action trial; if a previous
    trial failed, it first generates a reflection that is prepended to the
    next prompt via ``{reflections}``.
    """

    def __init__(self,
                 question: str,
                 context: str,
                 key: str,
                 agent_prompt: PromptTemplate = cot_reflect_agent_prompt,
                 reflect_prompt: PromptTemplate = cot_reflect_prompt,
                 reflect_header: str = REFLECTION_HEADER,
                 cot_examples: str = COT,
                 reflect_examples: str = COT_REFLECT,
                 # NOTE(review): these default LLMs are constructed once at
                 # module import time and shared by every instance, and they
                 # require OPENAI_API_KEY to be set at import -- confirm that
                 # this is intended.
                 self_reflect_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=250,
                                             model_name="text-davinci-003",
                                             model_kwargs={"stop": "\n"},
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 action_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=250,
                                             model_name="text-davinci-003",
                                             model_kwargs={"stop": "\n"},
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 ) -> None:
        # Task definition: question, supporting context, and the gold answer
        # key used for exact-match grading in is_correct().
        self.question = question
        self.context = context
        self.key = key
        self.agent_prompt = agent_prompt
        self.reflect_prompt = reflect_prompt
        self.reflect_header = reflect_header
        self.cot_examples = cot_examples
        self.reflect_examples = reflect_examples
        self.self_reflect_llm = self_reflect_llm
        self.action_llm = action_llm
        # Accumulated reflections (list form and pre-formatted string form).
        self.reflections: List[str] = []
        self.reflections_str = ''
        self.answer = ''
        # Number of completed trials; 0 means no attempt has been made yet.
        self.step_n: int = 0
        self.reset()

    def run(self, reflect: bool = True,
            reflect_strategy: Union[Literal['last_attempt'],
                                    Literal['reflexion'],
                                    Literal['last_attempt + reflexion']] = 'reflexion') -> None:
        """Execute one trial, reflecting first if the previous trial failed."""
        # Only reflect after at least one failed trial.
        if self.step_n > 0 and not self.is_correct() and reflect:
            self.reflect(reflect_strategy)
        self.reset()
        self.step()
        self.step_n += 1

    def step(self) -> None:
        """Run one Thought -> Action -> Observation exchange with the LLM."""
        # Think
        self.scratchpad += f'\nThought:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += f'\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])

        self.scratchpad += f'\nObservation: '
        if action_type == 'Finish':
            self.answer = argument
            # The grading observation is appended to the scratchpad so that a
            # later reflection can see whether the trial succeeded.
            if self.is_correct():
                self.scratchpad += 'Answer is CORRECT'
            else:
                self.scratchpad += 'Answer is INCORRECT'
            self.finished = True
            return
        else:
            # Any non-Finish (or unparseable) action ends the step without
            # setting self.finished; the trial simply yields no answer.
            print('Invalid action type, please try again.')

    def reflect(self,
                strategy: Union[Literal['last_attempt'],
                                Literal['reflexion'],
                                Literal['last_attempt + reflexion']]) -> None:
        """Build ``self.reflections_str`` according to *strategy*."""
        print('Reflecting...')
        if strategy == 'last_attempt':
            # Reuse the raw failed scratchpad as the "reflection".
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question , self.reflections[0])
        elif strategy == 'reflexion':
            # Append a fresh LLM-generated reflection to the running list.
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections)
        elif strategy == 'last_attempt + reflexion':
            # Show the last trial verbatim, then a new reflection on it.
            self.reflections_str = format_last_attempt(self.question , self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += '\n'+ format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(self.reflections_str)

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and normalize its output to one line."""
        return format_step(self.self_reflect_llm(self._build_reflection_prompt()))

    def reset(self) -> None:
        """Clear per-trial state (scratchpad and finished flag)."""
        self.scratchpad: str = ''
        self.finished = False

    def prompt_agent(self) -> str:
        """Query the action LLM and normalize its output to one line."""
        return format_step(self.action_llm(self._build_agent_prompt()))

    def _build_agent_prompt(self) -> str:
        # Fill the agent template with examples, reflections, context, and
        # the current trial state.
        return self.agent_prompt.format(
                            examples = self.cot_examples,
                            reflections = self.reflections_str,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def _build_reflection_prompt(self) -> str:
        # Fill the reflection template with the failed trial's scratchpad.
        return self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def is_finished(self) -> bool:
        """True once a Finish[...] action has been taken this trial."""
        return self.finished

    def is_correct(self) -> bool:
        """Exact-match (normalized) comparison of the answer to the key."""
        return EM(self.answer, self.key)
|
||||
|
||||
class ReactAgent:
    """ReAct QA agent: Thought / Action / Observation loop over a docstore.

    Actions are ``Search[entity]``, ``Lookup[keyword]`` and
    ``Finish[answer]``.  The loop stops when the agent finishes or when
    :meth:`is_halted` trips (step budget or prompt token budget exceeded).
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 agent_prompt: PromptTemplate = react_agent_prompt,
                 # NOTE(review): these defaults (Wikipedia docstore, OpenAI
                 # LLM) are constructed once at module import time and shared
                 # across instances; the LLM default also requires
                 # OPENAI_API_KEY at import -- confirm intended.
                 docstore: Docstore = Wikipedia(),
                 react_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=100,
                                             model_name="text-davinci-003",
                                             model_kwargs={"stop": "\n"},
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 ) -> None:
        self.question = question
        self.answer = ''
        self.key = key
        self.max_steps = max_steps
        self.agent_prompt = agent_prompt
        self.react_examples = WEBTHINK_SIMPLE6

        self.docstore = DocstoreExplorer(docstore) # Search, Lookup
        self.llm = react_llm

        # Tokenizer used only for the prompt-length check in is_halted().
        self.enc = tiktoken.encoding_for_model("text-davinci-003")

        self.__reset_agent()

    def run(self, reset = True) -> None:
        """Run Thought/Action/Observation steps until finished or halted."""
        if reset:
            self.__reset_agent()

        while not self.is_halted() and not self.is_finished():
            self.step()

    def step(self) -> None:
        """Execute one numbered Thought -> Action -> Observation step."""
        # Think
        self.scratchpad += f'\nThought {self.step_n}:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += f'\nAction {self.step_n}:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])

        # Observe
        self.scratchpad += f'\nObservation {self.step_n}: '

        if action_type == 'Finish':
            self.answer = argument
            # Grading is recorded in the scratchpad so a later reflection
            # can see the outcome.
            if self.is_correct():
                self.scratchpad += 'Answer is CORRECT'
            else:
                self.scratchpad += 'Answer is INCORRECT'
            self.finished = True
            self.step_n += 1
            return

        if action_type == 'Search':
            try:
                self.scratchpad += format_step(self.docstore.search(argument))
            except Exception as e:
                # Best-effort: surface the failure as an observation instead
                # of aborting the trial.
                print(e)
                self.scratchpad += f'Could not find that page, please try again.'

        elif action_type == 'Lookup':
            try:
                self.scratchpad += format_step(self.docstore.lookup(argument))
            except ValueError:
                # DocstoreExplorer.lookup raises ValueError when no page has
                # been successfully searched yet.
                self.scratchpad += f'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'

        else:
            # Unknown or unparseable action: tell the model what is valid.
            self.scratchpad += 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'

        print(self.scratchpad.split('\n')[-1])

        self.step_n += 1

    def prompt_agent(self) -> str:
        """Query the LLM and normalize its completion to one line."""
        return format_step(self.llm(self._build_agent_prompt()))

    def _build_agent_prompt(self) -> str:
        return self.agent_prompt.format(
                            examples = self.react_examples,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def is_finished(self) -> bool:
        """True once a Finish[...] action has been taken."""
        return self.finished

    def is_correct(self) -> bool:
        """Exact-match (normalized) comparison of the answer to the key."""
        return EM(self.answer, self.key)

    def is_halted(self) -> bool:
        # Halt when over the step budget or when the prompt would exceed
        # 3896 tokens (presumably leaving headroom for the completion within
        # the model's context window -- TODO confirm the constant).
        return ((self.step_n > self.max_steps) or (len(self.enc.encode(self._build_agent_prompt())) > 3896)) and not self.finished

    def __reset_agent(self) -> None:
        # Per-trial state; step numbering starts at 1.
        self.step_n = 1
        self.finished = False
        self.scratchpad: str = ''

    def set_qa(self, question: str, key: str) -> None:
        """Point this agent at a new question/answer pair."""
        self.question = question
        self.key = key
|
||||
|
||||
class ReactReflectAgent(ReactAgent):
    """ReAct agent that self-reflects between trials (Reflexion).

    Before re-running a failed (finished-but-wrong or halted) trial, it
    generates reflections that are injected into the agent prompt through
    the ``{reflections}`` slot of ``react_reflect_agent_prompt``.
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 agent_prompt: PromptTemplate = react_reflect_agent_prompt,
                 reflect_prompt: PromptTemplate = reflect_prompt,
                 reflect_header: str = REFLECTION_HEADER,
                 docstore: Docstore = Wikipedia(),
                 react_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=100,
                                             model_name="text-davinci-003",
                                             model_kwargs={"stop": "\n"},
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 # Reflection LLM gets a larger completion budget and no stop
                 # sequence, so it can write multi-sentence reflections.
                 reflect_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=250,
                                             model_name="text-davinci-003",
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 ) -> None:
        super().__init__(question, key, max_steps, agent_prompt, docstore, react_llm)
        self.reflect_header = reflect_header
        self.reflect_llm = reflect_llm
        self.reflect_prompt = reflect_prompt
        self.reflect_examples = REFLECTIONS
        self.reflections: List[str] = []
        self.reflections_str: str = ''

    def run(self, reset = True, reflect_strategy: Union[Literal['last_attempt'], Literal['reflexion'], Literal['last_attempt + reflexion']] = 'reflexion') -> None:
        """Reflect on the previous failed/halted trial, then run as usual."""
        if (self.is_finished() or self.is_halted()) and not self.is_correct():
            self.reflect(reflect_strategy)

        ReactAgent.run(self, reset)

    def reflect(self,
                strategy: Union[Literal['last_attempt'], Literal['reflexion'], Literal['last_attempt + reflexion']]) -> None:
        """Build ``self.reflections_str`` according to *strategy*."""
        print('Reflecting...')
        if strategy == 'last_attempt':
            # Reuse the raw failed scratchpad as the "reflection".
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question, self.reflections[0])
        elif strategy == 'reflexion':
            # Append a fresh LLM-generated reflection to the running list.
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections)
        elif strategy == 'last_attempt + reflexion':
            # Show the last trial verbatim, then a new reflection on it.
            self.reflections_str = format_last_attempt(self.question, self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(self.reflections_str)

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and normalize its output to one line."""
        return format_step(self.reflect_llm(self._build_reflection_prompt()))

    def _build_reflection_prompt(self) -> str:
        # The scratchpad is truncated so the reflection prompt itself cannot
        # blow the context window.
        return self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            question = self.question,
                            scratchpad = truncate_scratchpad(self.scratchpad, tokenizer=self.enc))

    def _build_agent_prompt(self) -> str:
        # Same as ReactAgent but with the accumulated reflections injected.
        return self.agent_prompt.format(
                            examples = self.react_examples,
                            reflections = self.reflections_str,
                            question = self.question,
                            scratchpad = self.scratchpad)
|
||||
|
||||
|
||||
### String Stuff ###
# Module-level tokenizer used for token counting in format_last_attempt and
# truncate_scratchpad (built once; encoding_for_model is relatively costly).
gpt2_enc = tiktoken.encoding_for_model("text-davinci-003")
|
||||
|
||||
def parse_action(string):
    """Parse an agent action of the form ``Type[argument]``.

    Args:
        string: Raw action text emitted by the LLM, e.g. ``Search[Tokyo]``.

    Returns:
        ``(action_type, argument)`` on success, ``(None, None)`` when the
        string does not match the expected pattern.
    """
    pattern = r'^(\w+)\[(.+)\]$'
    match = re.match(pattern, string)

    if match:
        action_type = match.group(1)
        argument = match.group(2)
        return action_type, argument

    # Bug fix: this used to return bare None, which crashed every caller
    # (they all unpack ``action_type, argument = parse_action(...)``).
    # Returning a pair lets malformed actions fall through to the agents'
    # "Invalid Action" branch instead of raising TypeError.
    return None, None
|
||||
|
||||
def format_step(step: str) -> str:
    """Collapse an LLM completion into a single clean line.

    Strips surrounding newlines and whitespace, then removes any interior
    newlines so the text can be appended to the scratchpad as one line.
    """
    trimmed = step.strip('\n').strip()
    return trimmed.replace('\n', '')
|
||||
|
||||
def format_reflections(reflections: List[str],
                       header: str = REFLECTION_HEADER) -> str:
    """Render reflections as a bulleted list under *header*.

    Returns the empty string when there are no reflections, so callers can
    unconditionally substitute the result into a prompt template.
    """
    if reflections == []:
        return ''
    bullets = '\n- '.join(r.strip() for r in reflections)
    return header + 'Reflections:\n- ' + bullets
|
||||
|
||||
def format_last_attempt(question: str,
                        scratchpad: str,
                        header: str = LAST_TRIAL_HEADER):
    """Render a failed trial (question + truncated scratchpad) for re-prompting.

    The scratchpad is token-truncated with the module-level gpt2_enc tokenizer
    and terminated with an explicit '(END PREVIOUS TRIAL)' marker.
    """
    return header + f'Question: {question}\n' + truncate_scratchpad(scratchpad, tokenizer=gpt2_enc).strip('\n').strip() + '\n(END PREVIOUS TRIAL)\n'
|
||||
|
||||
def truncate_scratchpad(scratchpad: str, n_tokens: int = 1600, tokenizer = gpt2_enc) -> str:
    """Shrink a scratchpad under a token budget by blanking Observation lines.

    Repeatedly replaces the largest remaining ``Observation ...`` line with a
    '[truncated wikipedia excerpt]' placeholder until the whole scratchpad
    encodes to at most *n_tokens* tokens (or no observations remain).

    Args:
        scratchpad: Newline-separated trial transcript.
        n_tokens: Token budget for the returned text.
        tokenizer: Object with an ``encode(str) -> list`` method.

    Returns:
        The (possibly truncated) scratchpad text.
    """
    lines = scratchpad.split('\n')
    observations = filter(lambda x: x.startswith('Observation'), lines)
    observations_by_tokens = sorted(observations, key=lambda x: len(tokenizer.encode(x)))
    # Bug fix: the loop previously measured with the module-global gpt2_enc,
    # ignoring the ``tokenizer`` argument callers pass in.
    while len(tokenizer.encode('\n'.join(lines))) > n_tokens:
        # Bug fix: guard against popping from an empty list when the text is
        # over budget but contains no (more) Observation lines to truncate.
        if not observations_by_tokens:
            break
        largest_observation = observations_by_tokens.pop(-1)
        ind = lines.index(largest_observation)
        lines[ind] = largest_observation.split(':')[0] + ': [truncated wikipedia excerpt]'
    return '\n'.join(lines)
|
||||
|
||||
def normalize_answer(s):
    """Normalize an answer string for exact-match comparison.

    Lowercases, strips punctuation, drops the articles a/an/the, and
    collapses runs of whitespace (standard SQuAD/HotPotQA normalization).
    """
    def _lower(text):
        return text.lower()

    def _strip_punctuation(text):
        punct = set(string.punctuation)
        return "".join(c for c in text if c not in punct)

    def _drop_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def _collapse_whitespace(text):
        return " ".join(text.split())

    result = _lower(s)
    result = _strip_punctuation(result)
    result = _drop_articles(result)
    return _collapse_whitespace(result)

def EM(answer, key) -> bool:
    """Exact match between *answer* and *key* after normalization."""
    return normalize_answer(answer) == normalize_answer(key)
|
||||
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
EdgeGPT==0.3.6
|
||||
gym==0.26.2
|
||||
joblib==1.2.0
|
||||
langchain==0.0.162
|
||||
numpy==1.24.1
|
||||
openai==0.27.4
|
||||
python-dotenv==1.0.0
|
||||
tenacity==8.2.2
|
||||
tiktoken==0.4.0
|
||||
transformers==4.28.1
|
||||
pandas==1.5.3
|
||||
scikit-learn
|
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue