You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
reflexion-human-eval/hotpotqa_runs/util.py

67 lines
2.4 KiB
Python

import os
import joblib
def summarize_trial(agents):
    """Partition ``agents`` by outcome.

    Returns a ``(correct, incorrect)`` tuple of lists. ``correct`` holds every
    agent whose ``is_correct()`` is truthy; ``incorrect`` holds every agent
    whose ``is_finished()`` is truthy while ``is_correct()`` is falsy. Agents
    that are neither correct nor finished appear in neither list.
    """
    solved = []
    for agent in agents:
        if agent.is_correct():
            solved.append(agent)

    failed = []
    for agent in agents:
        if not agent.is_finished():
            continue
        if agent.is_correct():
            continue
        failed.append(agent)

    return solved, failed
def remove_fewshot(prompt: str) -> str:
    """Return ``prompt`` with its few-shot example section cut out.

    The section is taken to start at the first ``'Here are some examples:'``
    and to end at the first ``'(END OF EXAMPLES)'``. The text before the start
    marker and the text after the end marker are each stripped of surrounding
    whitespace and joined with a single newline.

    If either marker is missing there is no delimited section to remove, so
    ``prompt`` is returned unchanged instead of raising ``IndexError`` or
    duplicating part of the text.
    """
    start_marker = 'Here are some examples:'
    end_marker = '(END OF EXAMPLES)'

    head, start_found, _ = prompt.partition(start_marker)
    # Partition on the first end marker only, so any later occurrence of the
    # marker is kept in the output rather than truncating what follows it.
    _, end_found, tail = prompt.partition(end_marker)

    if not (start_found and end_found):
        return prompt

    return head.strip() + '\n' + tail.strip()
def log_trial(agents, trial_n):
    """Build the text report for one trial and return it as a string.

    The report opens with a banner carrying ``trial_n`` and the number of
    correct and incorrect agents (as classified by ``summarize_trial``). It is
    followed by one section per group. Each agent entry is the agent's prompt
    (from ``_build_agent_prompt()``) passed through ``remove_fewshot``, then a
    ``Correct answer:`` line showing the agent's ``key``.
    """
    solved, failed = summarize_trial(agents)

    # Collect fragments and join once at the end.
    pieces = [
        "\n"
        "########################################\n"
        f"BEGIN TRIAL {trial_n}\n"
        f"Trial summary: Correct: {len(solved)}, Incorrect: {len(failed)}\n"
        "#######################################\n"
    ]

    sections = (
        ('------------- BEGIN CORRECT AGENTS -------------\n\n', solved),
        ('------------- BEGIN INCORRECT AGENTS -----------\n\n', failed),
    )
    for banner, group in sections:
        pieces.append(banner)
        for agent in group:
            pieces.append(remove_fewshot(agent._build_agent_prompt()))
            pieces.append(f'\nCorrect answer: {agent.key}\n\n')

    return ''.join(pieces)
def summarize_react_trial(agents):
    """Partition ``agents`` into correct, incorrect and halted groups.

    Returns a ``(correct, incorrect, halted)`` tuple of lists:

    * ``correct``   - agents whose ``is_correct()`` is truthy;
    * ``incorrect`` - agents whose ``is_finished()`` is truthy and whose
      ``is_correct()`` is falsy;
    * ``halted``    - agents whose ``is_halted()`` is truthy.

    The groups are computed independently, so an agent lands in every group
    whose condition it satisfies.
    """
    def finished_but_wrong(agent):
        return agent.is_finished() and not agent.is_correct()

    solved = list(filter(lambda agent: agent.is_correct(), agents))
    stopped = list(filter(lambda agent: agent.is_halted(), agents))
    failed = list(filter(finished_but_wrong, agents))
    return solved, failed, stopped
def log_react_trial(agents, trial_n):
    """Build the text report for one ReAct trial and return it as a string.

    The report opens with a banner carrying ``trial_n`` and the number of
    correct, incorrect and halted agents (as classified by
    ``summarize_react_trial``). It is followed by one section per group. Each
    agent entry is the agent's prompt (from ``_build_agent_prompt()``) passed
    through ``remove_fewshot``, then a ``Correct answer:`` line showing the
    agent's ``key``.
    """
    solved, failed, stopped = summarize_react_trial(agents)

    # Collect fragments and join once at the end.
    pieces = [
        "\n"
        "########################################\n"
        f"BEGIN TRIAL {trial_n}\n"
        f"Trial summary: Correct: {len(solved)}, Incorrect: {len(failed)}, Halted: {len(stopped)}\n"
        "#######################################\n"
    ]

    sections = (
        ('------------- BEGIN CORRECT AGENTS -------------\n\n', solved),
        ('------------- BEGIN INCORRECT AGENTS -----------\n\n', failed),
        ('------------- BEGIN HALTED AGENTS -----------\n\n', stopped),
    )
    for banner, group in sections:
        pieces.append(banner)
        for agent in group:
            pieces.append(remove_fewshot(agent._build_agent_prompt()))
            pieces.append(f'\nCorrect answer: {agent.key}\n\n')

    return ''.join(pieces)
def save_agents(agents, dir: str):
    """Dump every agent to its own file under ``dir`` using ``joblib.dump``.

    ``dir`` is created if it does not already exist. Files are named after the
    agent's position in ``agents``: ``0.joblib``, ``1.joblib``, and so on.
    """
    os.makedirs(dir, exist_ok=True)
    for position, agent in enumerate(agents):
        filename = f'{position}.joblib'
        target = os.path.join(dir, filename)
        joblib.dump(agent, target)