You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

102 lines
2.9 KiB

import re
import string
from typing import Optional, Tuple

import gym
from langchain import Wikipedia
from langchain.agents.react.base import DocstoreExplorer
class QAEnv(gym.Env):
    """Text environment for answering one question via Search/Lookup/Finish actions.

    Actions are strings of the form ``Search[<topic>]``, ``Lookup[<keyword>]``
    or ``Finish[<answer>]`` (parsed by ``parse_action``). Search and Lookup are
    delegated to a ``DocstoreExplorer``; Finish records the answer, which is
    scored against ``key`` with ``EM`` (normalised exact match).
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 explorer: Optional[DocstoreExplorer] = None):
        """Create an environment for a single question.

        Args:
            question: The question text (stored on the instance).
            key: Ground-truth answer that ``Finish`` answers are scored against.
            max_steps: Step budget consulted by ``is_truncated``.
            explorer: Explorer used for Search/Lookup. When omitted (or None),
                a new ``DocstoreExplorer(Wikipedia())`` is built for this
                instance.
        """
        self.question = question
        self.key = key
        self.max_steps = max_steps
        # Build the default per instance. A default-argument instance would be
        # created once at import time and shared by every environment, and the
        # explorer is stateful: Lookup works against the last Searched page.
        self.explorer = explorer if explorer is not None else DocstoreExplorer(Wikipedia())
        # Initialise episode state so step() works without an explicit reset().
        self.reset()

    def reset(self):
        """Start a new episode: zero the step counter, clear termination and the answer."""
        self.curr_step = 0
        self.terminated = False
        self.answer = ''

    def step(self, action: str) -> Tuple[str, bool, bool, bool, int]:
        """Apply one agent action and report the outcome.

        Args:
            action: Action text such as ``"Search[Paris]"``. Text that does not
                parse as ``Type[argument]``, or has an unknown type, yields an
                "Invalid Action" observation.

        Returns:
            ``(observation, reward, terminated, truncated, step_count)``:
            ``reward`` is ``is_correct()`` for the current answer, ``terminated``
            is True once a Finish action has been taken, ``truncated`` is
            evaluated before this step is counted, and ``step_count`` is
            ``curr_step`` after incrementing it for this step.
        """
        action_type, argument = parse_action(action)

        if action_type == 'Finish':
            self.answer = argument
            if self.is_correct():
                observation = 'Answer is CORRECT'
            else:
                observation = 'Answer is INCORRECT'
            self.terminated = True

        elif action_type == 'Search':
            try:
                observation = self.explorer.search(argument).strip('\n').strip()
            except Exception:
                # Any error raised by the explorer's search is turned into an
                # observation for the agent instead of aborting the episode.
                observation = 'Could not find that page, please try again.'

        elif action_type == 'Lookup':
            try:
                observation = self.explorer.lookup(argument).strip('\n').strip()
            except ValueError:
                observation = 'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'

        else:
            observation = 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'

        reward = self.is_correct()
        terminated = self.is_terminated()
        truncated = self.is_truncated()

        self.curr_step += 1

        return observation, reward, terminated, truncated, self.curr_step

    def is_correct(self) -> bool:
        """Return True if the recorded answer exactly matches ``key`` after normalisation."""
        return EM(self.answer, self.key)

    def is_terminated(self) -> bool:
        """Return True once a Finish action has been taken in this episode."""
        return self.terminated

    def is_truncated(self) -> bool:
        """Return True once the step counter has reached ``max_steps``.

        NOTE(review): step() calls this before incrementing ``curr_step``, so
        truncated=True is first reported on step number ``max_steps + 1``.
        Confirm whether callers rely on this timing.
        """
        return self.curr_step >= self.max_steps
def parse_action(string):
    """Split an agent action of the form ``Type[argument]`` into its parts.

    Args:
        string: Raw action text, e.g. ``"Search[Paris]"``. The parameter name
            is kept for backward compatibility; it shadows the ``string``
            module only inside this function.

    Returns:
        ``(action_type, argument)`` when the text matches, otherwise
        ``(None, None)``. The argument must be non-empty.
    """
    # (\w+) is the action name. (.+) is greedy, so nested brackets such as
    # "Lookup[a[b]]" keep everything up to the final "]".
    match = re.match(r'^(\w+)\[(.+)\]$', string)
    if match is None:
        return None, None
    action_type, argument = match.groups()
    return action_type, argument
def normalize_answer(s):
    """Canonicalise an answer string for exact-match comparison.

    The steps, in order, are:
    1. Lower-case the text.
    2. Delete every character in ``string.punctuation``.
    3. Replace the standalone words "a", "an" and "the" with a space.
    4. Collapse whitespace runs to single spaces and trim both ends.
    """
    text = s.lower()
    # One C-level pass deleting all ASCII punctuation characters.
    text = text.translate(str.maketrans('', '', string.punctuation))
    # \b keeps words such as "theatre" or "answer" intact.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def EM(answer, key) -> bool:
    """Return True when *answer* equals *key* after both are normalised.

    Both inputs are passed through ``normalize_answer`` (answer first, then
    key) and the results are compared with ``==``.
    """
    normalized_answer, normalized_key = normalize_answer(answer), normalize_answer(key)
    return normalized_answer == normalized_key