NBs and README
parent
e531a5c0d6
commit
c2159d4b93
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,210 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Notebook for running Chain-of-Thought with supporting context experiments "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, os\n",
|
||||
"sys.path.append('..')\n",
|
||||
"root = '../root/'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"import numpy as np\n",
|
||||
"from agents import CoTAgent, ReflexionStrategy\n",
|
||||
"from util import summarize_trial, log_trial, save_agents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Load the HotPotQA Sample"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('../data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)\n",
|
||||
"\n",
|
||||
"hotpot['supporting_paragraphs'] = None\n",
|
||||
"for ind, row in hotpot.iterrows():\n",
|
||||
" supporting_articles = row['supporting_facts']['title']\n",
|
||||
" articles = row['context']['title']\n",
|
||||
" sentences = row['context']['sentences'] \n",
|
||||
" supporting_paragraphs = []\n",
|
||||
" for article in supporting_articles:\n",
|
||||
" supporting_paragraph = ''.join(sentences[np.where(articles == article)][0])\n",
|
||||
" supporting_paragraphs.append(supporting_paragraph)\n",
|
||||
" supporting_paragraphs = '\\n\\n'.join(supporting_paragraphs)\n",
|
||||
" hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define the Reflexion Strategy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" NONE: No reflection\n",
|
||||
" LAST_ATTEMPT: Use last reasoning trace in context \n",
|
||||
" REFLEXION: Apply reflexion to the next reasoning trace \n",
|
||||
" LAST_ATTEMPT_AND_REFLEXION: Use last reasoning trace in context and apply reflexion to the next reasoning trace \n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(ReflexionStrategy.__doc__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Initialize a CoTAgent for each question"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt\n",
|
||||
"from fewshots import COT, COT_REFLECT\n",
|
||||
"agents = [CoTAgent(row['question'],\n",
|
||||
" row['supporting_paragraphs'],\n",
|
||||
" row['answer'],\n",
|
||||
" agent_prompt=cot_agent_prompt if strategy == ReflexionStrategy.NONE else cot_reflect_agent_prompt,\n",
|
||||
" cot_examples=COT,\n",
|
||||
" reflect_prompt=cot_reflect_prompt,\n",
|
||||
" reflect_examples=COT_REFLECT,\n",
|
||||
" ) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Run `n` trials"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n = 5\n",
|
||||
"trial = 0\n",
|
||||
"log = ''"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(n):\n",
|
||||
" for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflexion_strategy = strategy)\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
" trial += 1\n",
|
||||
" log += log_trial(agents, trial)\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Save the result log"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(os.path.join(root, 'CoT', 'context', strategy.value, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
|
||||
" f.write(log)\n",
|
||||
"save_agents(agents, os.path.join(root, 'CoT', 'context', strategy.value, 'agents'))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,190 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Notebook for running Chain-of-Thought with no supporting context experiments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, os\n",
|
||||
"sys.path.append('..')\n",
|
||||
"root = '../root/'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from util import summarize_trial, log_trial, save_agents\n",
|
||||
"import joblib\n",
|
||||
"from agents import CoTAgent, ReflexionStrategy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Load the HotPotQA Sample"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('../data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define the Reflexion Strategy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" NONE: No reflection\n",
|
||||
" LAST_ATTEMPT: Use last reasoning trace in context \n",
|
||||
" REFLEXION: Apply reflexion to the next reasoning trace \n",
|
||||
" LAST_ATTEMPT_AND_REFLEXION: Use last reasoning trace in context and apply reflexion to the next reasoning trace \n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(ReflexionStrategy.__doc__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Initialize a CoTAgent for each question"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from prompts import cot_simple_reflect_agent_prompt, cot_simple_reflect_prompt, cot_simple_agent_prompt\n",
|
||||
"from fewshots import COTQA_SIMPLE6, COT_SIMPLE_REFLECTION\n",
|
||||
"\n",
|
||||
"agents = [CoTAgent(question = row['question'],\n",
|
||||
" context = '',\n",
|
||||
" key = row['answer'],\n",
|
||||
" agent_prompt=cot_simple_agent_prompt if strategy == ReflexionStrategy.NONE else cot_simple_reflect_agent_prompt,\n",
|
||||
" cot_examples = COTQA_SIMPLE6,\n",
|
||||
" reflect_prompt = cot_simple_reflect_prompt,\n",
|
||||
" reflect_examples = COT_SIMPLE_REFLECTION,\n",
|
||||
" ) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Run `n` trials"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n = 5\n",
|
||||
"trial = 0\n",
|
||||
"log = ''\n",
|
||||
"for i in range(n):\n",
|
||||
" for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflexion_strategy = strategy)\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
" trial += 1\n",
|
||||
" log += log_trial(agents, trial)\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Save the result log"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(os.path.join(root, 'CoT', 'no_context', strategy.value, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
|
||||
" f.write(log)\n",
|
||||
"save_agents(agents, os.path.join(root, 'CoT', 'no_context', strategy.value, 'agents'))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,245 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"from react_cls import ReactAgent\n",
|
||||
"from mocks import DocStoreExplorerMock, LLMMock"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def summarize_trial(agents):\n",
|
||||
" correct = [a for a in agents if a.is_correct()]\n",
|
||||
" halted = [a for a in agents if a.is_halted()]\n",
|
||||
" incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
|
||||
" return correct, incorrect, halted\n",
|
||||
"\n",
|
||||
"def log_trial(agents, trial_n):\n",
|
||||
" correct, incorrect, halted = summarize_trial(agents)\n",
|
||||
"\n",
|
||||
" log = f\"\"\"\n",
|
||||
"########################################\n",
|
||||
"BEGIN TRIAL {trial_n}\n",
|
||||
"Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}\n",
|
||||
"#######################################\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
|
||||
" for agent in correct:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
|
||||
" for agent in incorrect:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN HALTED AGENTS --------------\\n\\n'\n",
|
||||
" for agent in halted:\n",
|
||||
" log += f'Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" return log"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agents = [ReactAgent(row['question'], row['answer']) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trial = 0\n",
|
||||
"log = ''"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"q = 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Trial: 4 (0/66)\n",
|
||||
"Trial: 4 (1/66)\n",
|
||||
"Trial: 4 (2/66)\n",
|
||||
"Trial: 4 (3/66)\n",
|
||||
"Trial: 4 (4/66)\n",
|
||||
"Trial: 4 (5/66)\n",
|
||||
"Trial: 4 (6/66)\n",
|
||||
"Trial: 4 (7/66)\n",
|
||||
"Trial: 4 (8/66)\n",
|
||||
"Trial: 4 (9/66)\n",
|
||||
"Trial: 4 (10/66)\n",
|
||||
"Trial: 4 (11/66)\n",
|
||||
"Trial: 4 (12/66)\n",
|
||||
"Trial: 4 (13/66)\n",
|
||||
"Trial: 4 (14/66)\n",
|
||||
"Trial: 4 (15/66)\n",
|
||||
"Trial: 4 (16/66)\n",
|
||||
"Trial: 4 (17/66)\n",
|
||||
"Trial: 4 (18/66)\n",
|
||||
"Trial: 4 (19/66)\n",
|
||||
"Trial: 4 (20/66)\n",
|
||||
"Trial: 4 (21/66)\n",
|
||||
"Trial: 4 (22/66)\n",
|
||||
"Trial: 4 (23/66)\n",
|
||||
"Trial: 4 (24/66)\n",
|
||||
"Trial: 4 (25/66)\n",
|
||||
"Trial: 4 (26/66)\n",
|
||||
"Trial: 4 (27/66)\n",
|
||||
"Trial: 4 (28/66)\n",
|
||||
"Trial: 4 (29/66)\n",
|
||||
"Trial: 4 (30/66)\n",
|
||||
"Trial: 4 (31/66)\n",
|
||||
"Trial: 4 (32/66)\n",
|
||||
"Trial: 4 (33/66)\n",
|
||||
"Trial: 4 (34/66)\n",
|
||||
"Trial: 4 (35/66)\n",
|
||||
"Trial: 4 (36/66)\n",
|
||||
"Trial: 4 (37/66)\n",
|
||||
"Trial: 4 (38/66)\n",
|
||||
"Trial: 4 (39/66)\n",
|
||||
"Trial: 4 (40/66)\n",
|
||||
"Trial: 4 (41/66)\n",
|
||||
"Trial: 4 (42/66)\n",
|
||||
"Trial: 4 (43/66)\n",
|
||||
"Trial: 4 (44/66)\n",
|
||||
"Trial: 4 (45/66)\n",
|
||||
"Trial: 4 (46/66)\n",
|
||||
"Trial: 4 (47/66)\n",
|
||||
"Trial: 4 (48/66)\n",
|
||||
"Trial: 4 (49/66)\n",
|
||||
"Trial: 4 (50/66)\n",
|
||||
"Trial: 4 (51/66)\n",
|
||||
"Trial: 4 (52/66)\n",
|
||||
"Trial: 4 (53/66)\n",
|
||||
"Trial: 4 (54/66)\n",
|
||||
"Trial: 4 (55/66)\n",
|
||||
"Trial: 4 (56/66)\n",
|
||||
"Trial: 4 (57/66)\n",
|
||||
"Trial: 4 (58/66)\n",
|
||||
"Trial: 4 (59/66)\n",
|
||||
"Trial: 4 (60/66)\n",
|
||||
"Trial: 4 (61/66)\n",
|
||||
"Trial: 4 (62/66)\n",
|
||||
"Trial: 4 (63/66)\n",
|
||||
"Trial: 4 (64/66)\n",
|
||||
"Trial: 4 (65/66)\n",
|
||||
"Finished Trial 5, Correct: 34, Incorrect: 56, Halted: 12\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agents_to_run = [a for a in agents if not a.is_correct()]\n",
|
||||
"\n",
|
||||
"while q < len(agents_to_run):\n",
|
||||
" print(f'Trial: {trial} ({q}/{len(agents_to_run)})')\n",
|
||||
" agents_to_run[q].run()\n",
|
||||
" q += 1\n",
|
||||
"\n",
|
||||
"trial += 1\n",
|
||||
"\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect, halted = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open('output/base_react/100_questions_5_trials.txt', 'w') as f:\n",
|
||||
" f.write(log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/base_react_dicts.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in agents]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/base_react_dicts.joblib')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.9"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,215 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"from react_cls import ReactReflectAgent, format_reflections\n",
|
||||
"from mocks import DocStoreExplorerMock, LLMMock"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def summarize_trial(agents):\n",
|
||||
" correct = [a for a in agents if a.is_correct()]\n",
|
||||
" incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
|
||||
" return correct, incorrect\n",
|
||||
"\n",
|
||||
"def remove_fewshot(prompt: str) -> str:\n",
|
||||
" prefix = prompt.split('Here are some examples:')[0]\n",
|
||||
" suffix = prompt.split('(END OF EXAMPLES)')[1]\n",
|
||||
" return prefix.strip('\\n').strip() +'\\n' + suffix.strip('\\n').strip()\n",
|
||||
"\n",
|
||||
"def log_trial(agents, trial_n):\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
"\n",
|
||||
" log = f\"\"\"\n",
|
||||
"########################################\n",
|
||||
"BEGIN TRIAL {trial_n}\n",
|
||||
"Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}\n",
|
||||
"#######################################\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
|
||||
" for agent in correct:\n",
|
||||
" log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
|
||||
" for agent in incorrect:\n",
|
||||
" log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
|
||||
"\n",
|
||||
" return log\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agents = [ReactReflectAgent(row['question'], row['answer']) for _, row in hotpot.iterrows()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trial = 0\n",
|
||||
"log = ''\n",
|
||||
"last_correct = 0 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
"trial += 1\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/last_trial_react/react_incorrect_dicts_trial_0.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in incorrect]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/last_trial_react/react_incorrect_dicts_trial_0.joblib')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"while last_correct != correct:\n",
|
||||
" last_correct, _ = summarize_trial(agents)\n",
|
||||
" for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
" trial += 1\n",
|
||||
" log += log_trial(agents, trial)\n",
|
||||
" correct, incorrect = summarize_trial(agents)\n",
|
||||
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for agent in [a for a in agents if not a.is_correct()]:\n",
|
||||
" agent.run(reflect_strategy='last_attempt + reflexion')\n",
|
||||
" print(f'Answer: {agent.key}')\n",
|
||||
"trial += 1\n",
|
||||
"log += log_trial(agents, trial)\n",
|
||||
"correct, incorrect = summarize_trial(agents)\n",
|
||||
"print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open('output/last_trial_react/100_questions_5_trials.txt', 'w') as f:\n",
|
||||
" f.write(log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['output/reflect/react_reflect_50_correct_dicts.joblib']"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dicts = [dict(a.__dict__) for a in correct]\n",
|
||||
"for d in dicts:\n",
|
||||
" for k, v in d.items():\n",
|
||||
" d[k] = str(v)\n",
|
||||
"\n",
|
||||
"joblib.dump(dicts, 'output/reflect/react_reflect_50_correct_dicts.joblib')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.9"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e23f799cbd2581634725fbf6ce3480ae26192d78438dfafc8efe944acd6490d5"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in New Issue