You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
reflexion-human-eval/hotpotqa_runs/environment.py

102 lines
2.9 KiB
Python

import re
import string
from typing import Tuple
import gym
from langchain import Wikipedia
from langchain.agents.react.base import DocstoreExplorer
class QAEnv(gym.Env):
    """Question-answering environment driven by textual actions.

    An episode is built around one ``question`` and its reference answer
    ``key``. Each call to :meth:`step` takes an action string of the form
    ``Search[<topic>]``, ``Lookup[<topic>]`` or ``Finish[<answer>]``.
    ``Search`` and ``Lookup`` are forwarded to ``explorer``; ``Finish``
    records the answer, compares it with ``key`` via ``EM`` and terminates
    the episode.
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 explorer: "DocstoreExplorer | None" = None):
        """Create the environment and reset it to its initial state.

        Args:
            question: The question posed for this episode.
            key: The reference answer a ``Finish`` action is compared with.
            max_steps: Step count at which :meth:`is_truncated` becomes true.
            explorer: Object providing ``search`` and ``lookup``. When
                omitted, a new ``DocstoreExplorer(Wikipedia())`` is built
                for this instance.
        """
        self.question = question
        self.key = key
        self.max_steps = max_steps
        # Build the default here rather than in the signature: a default
        # expression is evaluated once at definition time, which would make
        # every environment share one explorer (and Lookup depends on the
        # explorer's last searched page).
        if explorer is None:
            explorer = DocstoreExplorer(Wikipedia())
        self.explorer = explorer
        self.reset()

    def reset(self):
        """Reset the step counter, the terminated flag and the answer."""
        self.curr_step = 0
        self.terminated = False
        self.answer = ''

    def step(self, action: str) -> Tuple[str, bool, bool, bool, int]:
        """Execute one action and report the outcome.

        Args:
            action: Action string such as ``Search[Paris]``. Anything that
                ``parse_action`` cannot parse, or whose type is unknown,
                yields an "Invalid Action" observation.

        Returns:
            ``(observation, reward, terminated, truncated, curr_step)``
            where ``reward`` is whether the stored answer currently matches
            the key and ``curr_step`` is the counter after this step.
        """
        action_type, argument = parse_action(action)

        if action_type == 'Finish':
            self.answer = argument
            if self.is_correct():
                observation = 'Answer is CORRECT'
            else:
                observation = 'Answer is INCORRECT'
            self.terminated = True

        elif action_type == 'Search':
            try:
                observation = self.explorer.search(argument).strip('\n').strip()
            except Exception as e:
                # Best effort: any search failure is reported to the agent
                # as a missing page instead of aborting the episode.
                print(e)
                observation = 'Could not find that page, please try again.'

        elif action_type == 'Lookup':
            try:
                observation = self.explorer.lookup(argument).strip('\n').strip()
            except ValueError:
                observation = 'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'

        else:
            observation = 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'

        # The flags are read before the counter is advanced, so `truncated`
        # reflects the number of steps taken prior to this call.
        reward = self.is_correct()
        terminated = self.is_terminated()
        truncated = self.is_truncated()

        self.curr_step += 1

        return observation, reward, terminated, truncated, self.curr_step

    def is_correct(self) -> bool:
        """Return whether the stored answer matches the key under ``EM``."""
        return EM(self.answer, self.key)

    def is_terminated(self) -> bool:
        """Return whether a ``Finish`` action has been taken."""
        return self.terminated

    def is_truncated(self) -> bool:
        """Return whether the step counter has reached ``max_steps``."""
        return self.curr_step >= self.max_steps
def parse_action(string):
    """Split an action such as ``Search[topic]`` into its type and argument.

    Leading and trailing whitespace around the action is ignored, so
    ``" Finish[yes] "`` parses the same as ``"Finish[yes]"``.

    Args:
        string: The raw action text. (The parameter name shadows the
            ``string`` module inside this function only; it is kept so
            keyword callers continue to work.)

    Returns:
        ``(action_type, argument)`` when the text has the form
        ``<word>[<non-empty argument>]``, otherwise ``(None, None)``.
    """
    pattern = r'^(\w+)\[(.+)\]$'
    # Strip first: surrounding spaces would otherwise make a well-formed
    # action fail the anchored match and be treated as invalid.
    match = re.match(pattern, string.strip())
    if match is None:
        return None, None
    return match.group(1), match.group(2)
def normalize_answer(s):
    """Return a canonical form of ``s`` for answer comparison.

    The text is lower-cased, ASCII punctuation is removed, the standalone
    words "a", "an" and "the" are replaced by spaces, and finally all runs
    of whitespace are collapsed to single spaces with the ends trimmed.
    """
    lowered = s.lower()
    # Deleting every punctuation character in a single translate pass.
    without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())
def EM(answer, key) -> bool:
    """Exact-match check: compare ``answer`` and ``key`` after normalization.

    Both values are passed through ``normalize_answer`` and the results are
    compared for equality.
    """
    normalized_answer = normalize_answer(answer)
    normalized_key = normalize_answer(key)
    return normalized_answer == normalized_key