Run Hotpot experiments

hotpot-enhancement
Beck LaBash 9 months ago
parent a8e13b1b0f
commit cefdb0ed53

3
.gitignore vendored

@ -1,3 +1,4 @@
agents/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@ -133,4 +134,4 @@ scratch/
scratch/
.vscode/
.DS_Store/
.DS_Store

@ -7,6 +7,8 @@ from langchain.llms.base import BaseLLM
from langchain.agents.react.base import DocstoreExplorer
from langchain.docstore.base import Docstore
from langchain.prompts import PromptTemplate
from llm import AnyOpenAILLM
import logging
from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER, LAST_TRIAL_HEADER, REFLECTION_AFTER_LAST_TRIAL_HEADER
from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt, COT_INSTRUCTION, COT_REFLECT_INSTRUCTION
from fewshots import WEBTHINK_SIMPLE6, REFLECTIONS, COT, COT_REFLECT
@ -59,6 +61,7 @@ class CoTAgent:
self.reflections_str = ''
self.answer = ''
self.step_n: int = 0
self.enc = tiktoken.encoding_for_model(self.self_reflect_llm.model_name)
self.reset()
def run(self,
@ -73,14 +76,14 @@ class CoTAgent:
# Think
self.scratchpad += f'\nThought:'
self.scratchpad += ' ' + self.prompt_agent()
print(self.scratchpad.split('\n')[-1])
logging.info(self.scratchpad.split('\n')[-1])
# Act
self.scratchpad += f'\nAction:'
action = self.prompt_agent()
self.scratchpad += ' ' + action
action_type, argument = parse_action(action)
print(self.scratchpad.split('\n')[-1])
logging.info(self.scratchpad.split('\n')[-1])
self.scratchpad += f'\nObservation: '
if action_type == 'Finish':
@ -92,24 +95,24 @@ class CoTAgent:
self.finished = True
return
else:
print('Invalid action type, please try again.')
self.scratchpad += 'Invalid Action. Valid Actions are Finish[<answer>].'
def reflect(self,
strategy: ReflexionStrategy) -> None:
print('Running Reflexion strategy...')
logging.info('Running Reflexion strategy...')
if strategy == ReflexionStrategy.LAST_ATTEMPT:
self.reflections = [self.scratchpad]
self.reflections_str = format_last_attempt(self.question , self.reflections[0])
self.reflections_str = format_last_attempt(self.question, self.reflections[0], self.enc)
elif strategy == ReflexionStrategy.REFLEXION:
self.reflections += [self.prompt_reflection()]
self.reflections_str = format_reflections(self.reflections)
elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
self.reflections_str = format_last_attempt(self.question , self.scratchpad)
self.reflections_str = format_last_attempt(self.question , self.scratchpad, self.enc)
self.reflections = [self.prompt_reflection()]
self.reflections_str += '\n'+ format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
else:
raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
print(self.reflections_str)
logging.info('Reflections: ' + self.reflections_str)
def prompt_reflection(self) -> str:
return format_step(self.self_reflect_llm(self._build_reflection_prompt()))
@ -168,7 +171,7 @@ class ReactAgent:
self.docstore = DocstoreExplorer(docstore) # Search, Lookup
self.llm = react_llm
self.enc = tiktoken.encoding_for_model("text-davinci-003")
self.enc = tiktoken.encoding_for_model(react_llm.model_name)
self.__reset_agent()
@ -183,14 +186,14 @@ class ReactAgent:
# Think
self.scratchpad += f'\nThought {self.step_n}:'
self.scratchpad += ' ' + self.prompt_agent()
print(self.scratchpad.split('\n')[-1])
logging.info(self.scratchpad.split('\n')[-1])
# Act
self.scratchpad += f'\nAction {self.step_n}:'
action = self.prompt_agent()
self.scratchpad += ' ' + action
action_type, argument = parse_action(action)
print(self.scratchpad.split('\n')[-1])
logging.info(self.scratchpad.split('\n')[-1])
# Observe
self.scratchpad += f'\nObservation {self.step_n}: '
@ -209,7 +212,7 @@ class ReactAgent:
try:
self.scratchpad += format_step(self.docstore.search(argument))
except Exception as e:
print(e)
logging.info('Wikipedia error: ' + e)
self.scratchpad += f'Could not find that page, please try again.'
elif action_type == 'Lookup':
@ -221,7 +224,7 @@ class ReactAgent:
else:
self.scratchpad += 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'
print(self.scratchpad.split('\n')[-1])
logging.info(self.scratchpad.split('\n')[-1])
self.step_n += 1
@ -288,20 +291,20 @@ class ReactReflectAgent(ReactAgent):
def reflect(self,
strategy: ReflexionStrategy) -> None:
print('Reflecting...')
logging.info('Reflecting...')
if strategy == ReflexionStrategy.LAST_ATTEMPT:
self.reflections = [self.scratchpad]
self.reflections_str = format_last_attempt(self.question, self.reflections[0])
self.reflections_str = format_last_attempt(self.question, self.reflections[0], self.enc)
elif strategy == ReflexionStrategy.REFLEXION:
self.reflections += [self.prompt_reflection()]
self.reflections_str = format_reflections(self.reflections)
elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
self.reflections_str = format_last_attempt(self.question, self.scratchpad)
self.reflections_str = format_last_attempt(self.question, self.scratchpad, self.enc)
self.reflections = [self.prompt_reflection()]
self.reflections_str += format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
else:
raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
print(self.reflections_str)
logging.info(self.reflections_str)
def prompt_reflection(self) -> str:
return format_step(self.reflect_llm(self._build_reflection_prompt()))
@ -333,7 +336,8 @@ def parse_action(string):
return action_type, argument
else:
return None
logging.error(f'Invalid action: {string}')
return None, string
def format_step(step: str) -> str:
return step.strip('\n').strip().replace('\n', '')
@ -347,10 +351,11 @@ def format_reflections(reflections: List[str],
def format_last_attempt(question: str,
scratchpad: str,
header: str = LAST_TRIAL_HEADER):
return header + f'Question: {question}\n' + truncate_scratchpad(scratchpad, tokenizer=gpt2_enc).strip('\n').strip() + '\n(END PREVIOUS TRIAL)\n'
header: str = LAST_TRIAL_HEADER,
tokenizer = gpt2_enc):
return header + f'Question: {question}\n' + truncate_scratchpad(scratchpad, tokenizer=tokenizer).strip('\n').strip() + '\n(END PREVIOUS TRIAL)\n'
def truncate_scratchpad(scratchpad: str, n_tokens: int = 1600, tokenizer = gpt2_enc) -> str:
def truncate_scratchpad(scratchpad: str, n_tokens: int = 1350, tokenizer = gpt2_enc) -> str:
lines = scratchpad.split('\n')
observations = filter(lambda x: x.startswith('Observation'), lines)
observations_by_tokens = sorted(observations, key=lambda x: len(tokenizer.encode(x)))

@ -0,0 +1,31 @@
import os
from typing import Union, Literal
from langchain.chat_models import ChatOpenAI
from langchain import OpenAI
from langchain.schema import (
HumanMessage
)
class AnyOpenAILLM:
def __init__(self, *args, **kwargs):
# Determine model type from the kwargs
model_name = kwargs.get('model_name', 'gpt-3.5-turbo')
if model_name.split('-')[0] == 'text':
self.model = OpenAI(*args, **kwargs)
self.model_type = 'completion'
else:
self.model = ChatOpenAI(*args, **kwargs)
self.model_type = 'chat'
self.model_name = model_name
def __call__(self, prompt: str):
if self.model_type == 'completion':
return self.model(prompt)
else:
return self.model(
[
HumanMessage(
content=prompt,
)
]
).content

@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -21,14 +21,38 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dotenv\n",
"dotenv.load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import joblib\n",
"import json\n",
"import logging\n",
"import numpy as np\n",
"from llm import AnyOpenAILLM\n",
"from agents import CoTAgent, ReflexionStrategy\n",
"from util import summarize_trial, log_trial, save_agents"
"from util import summarize_trial, log_trial, save_agents, save_results"
]
},
{
@ -41,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -70,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@ -92,13 +116,30 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define the LLMs"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"self_reflect_llm = \"gpt-4\"\n",
"action_llm = \"gpt-4\""
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -109,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -122,6 +163,18 @@
" cot_examples=COT,\n",
" reflect_prompt=cot_reflect_prompt,\n",
" reflect_examples=COT_REFLECT,\n",
" self_reflect_llm= AnyOpenAILLM(\n",
" temperature=0,\n",
" max_tokens=250,\n",
" model_name=self_reflect_llm,\n",
" model_kwargs={\"stop\": \"\\n\"},\n",
" openai_api_key=os.environ['OPENAI_API_KEY']),\n",
" action_llm= AnyOpenAILLM(\n",
" temperature=0,\n",
" max_tokens=250,\n",
" model_name=action_llm,\n",
" model_kwargs={\"stop\": \"\\n\"},\n",
" openai_api_key=os.environ['OPENAI_API_KEY']),\n",
" ) for _, row in hotpot.iterrows()]"
]
},
@ -135,28 +188,212 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"n = 5\n",
"trial = 0\n",
"log = ''"
"n = 9 \n",
"trial = 5 \n",
"log = ''\n",
"results = []\n",
"\n",
"llm = 'gpt4'\n",
"run_dir = os.path.join(root, 'CoT', 'context', strategy.value + f'_{llm}')\n",
"os.makedirs(run_dir, exist_ok=True)\n",
"\n",
"# Basic config\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" format=\"%(asctime)s [%(levelname)s] %(message)s\",\n",
" handlers=[\n",
" logging.FileHandler(os.path.join(run_dir, f'{n}_{trial}.log')),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Optional: Load agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"for i in range(n):\n",
"agents = []\n",
"agents_dir = os.path.join(run_dir, 'agents')\n",
"for file in os.listdir(agents_dir):\n",
" agents.append(joblib.load(os.path.join(agents_dir, file)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: 22 November\n",
"Answer: Roman\n",
"Answer: in the village of Aldenham\n",
"Answer: author\n",
"Answer: a failed coup attempt\n",
"Answer: novelist\n",
"Answer: California\n",
"Answer: super-regional shopping mall\n",
"Answer: singer, songwriter\n",
"Answer: German\n",
"Answer: the port of Baltimore west to Sandy Hook\n",
"Answer: New York\n",
"Answer: Frederick Alexander\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Invalid action: No, only Affiliated Managers Group.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: no\n",
"Answer: cleaning, catering and security\n",
"Answer: chronological collection of critical quotations\n",
"Answer: The Bad Hemingway Contest\n",
"Answer: \"Now and Then\" (1995)\n",
"Answer: no\n",
"Answer: the Cold War (194791)\n",
"Answer: fortnightly women interest magazine\n",
"Answer: 2 March 1972\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Invalid action: Upon further research, I have confirmed that Larnelle Steward Harris was indeed born in the month of July. Therefore, the answer to the question is Yes, David Huntsinger has worked with a gospel singer born in the month of July.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: Larnelle Harris\n",
"Answer: The Bears\n",
"Answer: Vivendi S.A.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Invalid action: No, only Maxillaria is a genus of orchids.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: no\n",
"Answer: fictional character\n",
"Answer: 10 January 1920\n",
"Answer: Ricard Rubio i Vives\n",
"Finished Trial 6, Correct: 71, Incorrect: 26\n",
"Answer: 22 November\n",
"Answer: Roman\n",
"Answer: in the village of Aldenham\n",
"Answer: author\n",
"Answer: a failed coup attempt\n",
"Answer: novelist\n",
"Answer: California\n",
"Answer: super-regional shopping mall\n",
"Answer: singer, songwriter\n",
"Answer: German\n",
"Answer: the port of Baltimore west to Sandy Hook\n",
"Answer: New York\n",
"Answer: Frederick Alexander\n",
"Answer: no\n",
"Answer: cleaning, catering and security\n",
"Answer: chronological collection of critical quotations\n",
"Answer: The Bad Hemingway Contest\n",
"Answer: \"Now and Then\" (1995)\n",
"Answer: no\n",
"Answer: the Cold War (194791)\n",
"Answer: fortnightly women interest magazine\n",
"Answer: 2 March 1972\n",
"Answer: Larnelle Harris\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Invalid action: The context does not provide information about the mascot of Mercer University.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: The Bears\n",
"Answer: Vivendi S.A.\n",
"Answer: no\n",
"Answer: fictional character\n",
"Answer: 10 January 1920\n",
"Answer: Ricard Rubio i Vives\n",
"Finished Trial 7, Correct: 71, Incorrect: 28\n",
"Answer: 22 November\n",
"Answer: Roman\n",
"Answer: in the village of Aldenham\n",
"Answer: author\n",
"Answer: a failed coup attempt\n",
"Answer: novelist\n",
"Answer: California\n",
"Answer: super-regional shopping mall\n",
"Answer: singer, songwriter\n",
"Answer: German\n",
"Answer: the port of Baltimore west to Sandy Hook\n",
"Answer: New York\n",
"Answer: Frederick Alexander\n",
"Answer: no\n",
"Answer: cleaning, catering and security\n",
"Answer: chronological collection of critical quotations\n",
"Answer: The Bad Hemingway Contest\n",
"Answer: \"Now and Then\" (1995)\n",
"Answer: no\n",
"Answer: the Cold War (194791)\n",
"Answer: fortnightly women interest magazine\n",
"Answer: 2 March 1972\n",
"Answer: Larnelle Harris\n",
"Answer: The Bears\n",
"Answer: Vivendi S.A.\n",
"Answer: no\n",
"Answer: fictional character\n",
"Answer: 10 January 1920\n",
"Answer: Ricard Rubio i Vives\n",
"Finished Trial 8, Correct: 71, Incorrect: 29\n"
]
}
],
"source": [
"for i in range(3):\n",
" for agent in [a for a in agents if not a.is_correct()]:\n",
" agent.run(reflexion_strategy = strategy)\n",
" print(f'Answer: {agent.key}')\n",
" trial += 1\n",
" log += log_trial(agents, trial)\n",
" correct, incorrect = summarize_trial(agents)\n",
" results.append({'trial': trial, 'correct': len(correct), 'incorrect': len(incorrect)})\n",
" save_results(agents, results, run_dir)\n",
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')"
]
},
@ -170,13 +407,120 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving agent 0...\n",
"Saving agent 1...\n",
"Saving agent 2...\n",
"Saving agent 3...\n",
"Saving agent 4...\n",
"Saving agent 5...\n",
"Saving agent 6...\n",
"Saving agent 7...\n",
"Saving agent 8...\n",
"Saving agent 9...\n",
"Saving agent 10...\n",
"Saving agent 11...\n",
"Saving agent 12...\n",
"Saving agent 13...\n",
"Saving agent 14...\n",
"Saving agent 15...\n",
"Saving agent 16...\n",
"Saving agent 17...\n",
"Saving agent 18...\n",
"Saving agent 19...\n",
"Saving agent 20...\n",
"Saving agent 21...\n",
"Saving agent 22...\n",
"Saving agent 23...\n",
"Saving agent 24...\n",
"Saving agent 25...\n",
"Saving agent 26...\n",
"Saving agent 27...\n",
"Saving agent 28...\n",
"Saving agent 29...\n",
"Saving agent 30...\n",
"Saving agent 31...\n",
"Saving agent 32...\n",
"Saving agent 33...\n",
"Saving agent 34...\n",
"Saving agent 35...\n",
"Saving agent 36...\n",
"Saving agent 37...\n",
"Saving agent 38...\n",
"Saving agent 39...\n",
"Saving agent 40...\n",
"Saving agent 41...\n",
"Saving agent 42...\n",
"Saving agent 43...\n",
"Saving agent 44...\n",
"Saving agent 45...\n",
"Saving agent 46...\n",
"Saving agent 47...\n",
"Saving agent 48...\n",
"Saving agent 49...\n",
"Saving agent 50...\n",
"Saving agent 51...\n",
"Saving agent 52...\n",
"Saving agent 53...\n",
"Saving agent 54...\n",
"Saving agent 55...\n",
"Saving agent 56...\n",
"Saving agent 57...\n",
"Saving agent 58...\n",
"Saving agent 59...\n",
"Saving agent 60...\n",
"Saving agent 61...\n",
"Saving agent 62...\n",
"Saving agent 63...\n",
"Saving agent 64...\n",
"Saving agent 65...\n",
"Saving agent 66...\n",
"Saving agent 67...\n",
"Saving agent 68...\n",
"Saving agent 69...\n",
"Saving agent 70...\n",
"Saving agent 71...\n",
"Saving agent 72...\n",
"Saving agent 73...\n",
"Saving agent 74...\n",
"Saving agent 75...\n",
"Saving agent 76...\n",
"Saving agent 77...\n",
"Saving agent 78...\n",
"Saving agent 79...\n",
"Saving agent 80...\n",
"Saving agent 81...\n",
"Saving agent 82...\n",
"Saving agent 83...\n",
"Saving agent 84...\n",
"Saving agent 85...\n",
"Saving agent 86...\n",
"Saving agent 87...\n",
"Saving agent 88...\n",
"Saving agent 89...\n",
"Saving agent 90...\n",
"Saving agent 91...\n",
"Saving agent 92...\n",
"Saving agent 93...\n",
"Saving agent 94...\n",
"Saving agent 95...\n",
"Saving agent 96...\n",
"Saving agent 97...\n",
"Saving agent 98...\n",
"Saving agent 99...\n"
]
}
],
"source": [
"with open(os.path.join(root, 'CoT', 'context', strategy.value, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
"with open(os.path.join(run_dir, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
" f.write(log)\n",
"save_agents(agents, os.path.join(root, 'CoT', 'context', strategy.value, 'agents'))"
"save_agents(agents, os.path.join(run_dir, 'agents'))"
]
}
],

@ -19,6 +19,27 @@
"root = '../root/'"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dotenv\n",
"dotenv.load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 15,
@ -176,7 +197,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.4"
},
"orig_nbformat": 4,
"vscode": {

File diff suppressed because it is too large Load Diff

@ -10,10 +10,11 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import sys, os\n",
"sys.path.append('..')\n",
"root = '../root/'"
@ -21,13 +22,36 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dotenv\n",
"dotenv.load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import joblib\n",
"from util import summarize_react_trial, log_react_trial, save_agents\n",
"from agents import ReactReflectAgent, ReactAgent, ReflexionStrategy"
"from util import summarize_react_trial, log_react_trial, save_agents, save_results\n",
"from agents import ReactReflectAgent, ReactAgent, ReflexionStrategy\n",
"from llm import AnyOpenAILLM"
]
},
{
@ -40,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -57,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@ -79,13 +103,30 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define the LLMs"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"action_llm = \"gpt-4\"\n",
"self_reflect_llm = \"gpt-4\""
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -96,12 +137,38 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"agent_cls = ReactReflectAgent if strategy != ReflexionStrategy.NONE else ReactAgent\n",
"agents = [agent_cls(row['question'], row['answer']) for _, row in hotpot.iterrows()]"
"if strategy != ReflexionStrategy.NONE:\n",
" agents = [ReactReflectAgent(\n",
" row['question'],\n",
" row['answer'],\n",
" react_llm = AnyOpenAILLM(\n",
" temperature=0,\n",
" max_tokens=100,\n",
" model_name=action_llm,\n",
" model_kwargs={\"stop\": \"\\n\"},\n",
" openai_api_key=os.environ['OPENAI_API_KEY']),\n",
" reflect_llm = AnyOpenAILLM(\n",
" temperature=0,\n",
" max_tokens=250,\n",
" model_name=self_reflect_llm,\n",
" openai_api_key=os.environ['OPENAI_API_KEY']),\n",
" ) for _, row in hotpot.iterrows()]\n",
"else:\n",
" agents = [ReactAgent(\n",
" row['question'],\n",
" row['answer'],\n",
" react_llm = AnyOpenAILLM(\n",
" temperature=0,\n",
" max_tokens=100,\n",
" model_name=action_llm,\n",
" model_kwargs={\"stop\": \"\\n\"},\n",
" openai_api_key=os.environ['OPENAI_API_KEY']),\n",
" ) for _, row in hotpot.iterrows()]"
]
},
{
@ -120,14 +187,49 @@
"source": [
"n = 5\n",
"trial = 0\n",
"log = ''"
"log = ''\n",
"results = []\n",
"\n",
"run_dir = os.path.join(root, 'ReAct', strategy.value + '_'+ self_reflect_llm)\n",
"os.makedirs(run_dir, exist_ok=True)\n",
"\n",
"# Basic logging config that writes to a file\n",
"logging.basicConfig(\n",
" filename=os.path.join(run_dir, f'{n}_questions.log'),\n",
" filemode='w',\n",
" format='%(asctime)s - %(levelname)s - %(message)s',\n",
" level=logging.INFO\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/becklabash/Documents/Research/reflexion/reflexion/env/lib/python3.11/site-packages/wikipedia/wikipedia.py:389: GuessedAtParserWarning: No parser was explicitly specified, so I'm using the best available HTML parser for this system (\"html.parser\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n",
"\n",
"The code that caused this warning is on line 389 of the file /Users/becklabash/Documents/Research/reflexion/reflexion/env/lib/python3.11/site-packages/wikipedia/wikipedia.py. To get rid of this warning, pass the additional argument 'features=\"html.parser\"' to the BeautifulSoup constructor.\n",
"\n",
" lis = BeautifulSoup(html).find_all('li')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finished Trial 1, Correct: 39, Incorrect: 45, Halted: 16\n",
"Finished Trial 2, Correct: 47, Incorrect: 30, Halted: 23\n",
"Finished Trial 3, Correct: 51, Incorrect: 29, Halted: 20\n",
"Finished Trial 4, Correct: 52, Incorrect: 25, Halted: 23\n",
"Finished Trial 5, Correct: 53, Incorrect: 26, Halted: 21\n"
]
}
],
"source": [
"for i in range(n):\n",
" for agent in [a for a in agents if not a.is_correct()]:\n",
@ -135,10 +237,13 @@
" agent.run(reflect_strategy = strategy)\n",
" else:\n",
" agent.run()\n",
" print(f'Answer: {agent.key}')\n",
" logging.info(f'Answer: {agent.key}')\n",
" trial += 1\n",
" log += log_react_trial(agents, trial)\n",
" correct, incorrect, halted = summarize_react_trial(agents)\n",
" results.append({\"trial\": trial, \"correct\": len(correct), \"incorrect\": len(incorrect), \"halted\": len(halted)})\n",
" save_results(agents, results, run_dir)\n",
" #save_agents(agents, os.path.join(run_dir, 'ReAct', strategy.value, 'agents', f'trial_{trial}'))\n",
" print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}')"
]
},
@ -156,7 +261,7 @@
"metadata": {},
"outputs": [],
"source": [
"with open(os.path.join(root, 'ReAct', strategy.value, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
"with open(os.path.join(run_dir, f'{len(agents)}_questions_{trial}_trials.txt'), 'w') as f:\n",
" f.write(log)\n",
"save_agents(agents, os.path.join('ReAct', strategy.value, 'agents'))"
]
@ -178,7 +283,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.4"
},
"orig_nbformat": 4,
"vscode": {

@ -0,0 +1 @@
[{"trial": 1, "correct": 56, "incorrect": 44}, {"trial": 2, "correct": 56, "incorrect": 44}, {"trial": 3, "correct": 57, "incorrect": 43}, {"trial": 4, "correct": 57, "incorrect": 43}, {"trial": 5, "correct": 57, "incorrect": 43}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 68, "incorrect": 32}, {"trial": 2, "correct": 68, "incorrect": 32}, {"trial": 3, "correct": 68, "incorrect": 32}, {"trial": 4, "correct": 68, "incorrect": 32}, {"trial": 5, "correct": 68, "incorrect": 32}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 57, "incorrect": 43}, {"trial": 2, "correct": 67, "incorrect": 32}, {"trial": 3, "correct": 68, "incorrect": 30}, {"trial": 4, "correct": 71, "incorrect": 28}, {"trial": 5, "correct": 71, "incorrect": 27}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 57, "incorrect": 43}, {"trial": 2, "correct": 67, "incorrect": 32}, {"trial": 3, "correct": 68, "incorrect": 30}, {"trial": 4, "correct": 71, "incorrect": 28}, {"trial": 5, "correct": 71, "incorrect": 27}, {"trial": 6, "correct": 71, "incorrect": 26}, {"trial": 6, "correct": 71, "incorrect": 26}, {"trial": 7, "correct": 71, "incorrect": 28}, {"trial": 6, "correct": 71, "incorrect": 26}, {"trial": 7, "correct": 71, "incorrect": 28}, {"trial": 8, "correct": 71, "incorrect": 29}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 68, "incorrect": 32}, {"trial": 2, "correct": 76, "incorrect": 24}, {"trial": 3, "correct": 77, "incorrect": 23}, {"trial": 4, "correct": 80, "incorrect": 20}, {"trial": 5, "correct": 80, "incorrect": 20}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 26, "incorrect": 44, "halted": 30}, {"trial": 2, "correct": 33, "incorrect": 17, "halted": 50}, {"trial": 3, "correct": 34, "incorrect": 27, "halted": 39}, {"trial": 4, "correct": 38, "incorrect": 30, "halted": 32}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 24, "incorrect": 43, "halted": 33}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 40, "incorrect": 41, "halted": 19}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 40, "incorrect": 41, "halted": 19}, {"trial": 1, "correct": 39, "incorrect": 45, "halted": 16}, {"trial": 1, "correct": 39, "incorrect": 45, "halted": 16}, {"trial": 2, "correct": 47, "incorrect": 30, "halted": 23}, {"trial": 1, "correct": 39, "incorrect": 45, "halted": 16}, {"trial": 2, "correct": 47, "incorrect": 30, "halted": 23}, {"trial": 3, "correct": 51, "incorrect": 29, "halted": 20}, {"trial": 1, "correct": 39, "incorrect": 45, "halted": 16}, {"trial": 2, "correct": 47, "incorrect": 30, "halted": 23}, {"trial": 3, "correct": 51, "incorrect": 29, "halted": 20}, {"trial": 4, "correct": 52, "incorrect": 25, "halted": 23}, {"trial": 1, "correct": 39, "incorrect": 45, "halted": 16}, {"trial": 2, "correct": 47, "incorrect": 30, "halted": 23}, {"trial": 3, "correct": 51, "incorrect": 29, "halted": 20}, {"trial": 4, "correct": 52, "incorrect": 25, "halted": 23}, {"trial": 5, "correct": 53, "incorrect": 26, "halted": 21}]

@ -0,0 +1 @@
[{"trial": 1, "correct": 42, "incorrect": 39, "halted": 19}]

@ -1,5 +1,7 @@
import os
import json
import joblib
from typing import Dict
def summarize_trial(agents):
correct = [a for a in agents if a.is_correct()]
@ -64,4 +66,18 @@ Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {le
def save_agents(agents, dir: str):
os.makedirs(dir, exist_ok=True)
for i, agent in enumerate(agents):
joblib.dump(agent, os.path.join(dir, f'{i}.joblib'))
print(f'Saving agent {i}...')
joblib.dump(agent, os.path.join(dir, f'{i}.joblib'))
def save_results(agents, results: Dict, run_dir: str):
result_file = os.path.join(run_dir, f'{len(agents)}_questions_results.json')
# Check if result file exists
existing_results = []
if os.path.exists(result_file):
# Load existing results
with open(result_file, 'r') as f:
existing_results = json.load(f)
# Add new results
existing_results.extend(results)
with open(os.path.join(run_dir, f'{len(agents)}_questions_results.json'), 'w') as f:
json.dump(existing_results, f)

@ -13,8 +13,10 @@ PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant.
# The first line of your response should have 4 spaces of indentation so that it fits syntactically with the user provided signature.
PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
# The first line of your response should have 4 spaces of indentation so that it fits syntactically with the user provided signature.
PY_SIMPLE_CHAT_INSTRUCTION_V2 = "You are an AI that only responds with only python code. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)."
PY_REFLEXION_CHAT_INSTRUCTION_V2 = "You are an AI Python assistant. You will be given your previous implementation of a function, a series of unit tests results, and your self-reflection on your previous implementation. Write your full implementation (restate the function signature)."
PY_REFLEXION_FEW_SHOT_ADD = '''Example 1:
[previous impl]:

@ -50,10 +50,10 @@ END EXAMPLES
RS_TEST_GENERATION_FEW_SHOT = """For example:
func signature:
/// For a given number n, find the largest number that divides n evenly, smaller than n
/// >>> largest_divisor(15)
/// 5
fn largest_divisor(n: isize) -> isize {
/// For a given number n, find the largest number that divides n evenly, smaller than n
/// >>> largest_divisor(15)
/// 5
for i in (1..n).rev() {
if n % i == 0 {
return i;

Loading…
Cancel
Save