diff --git a/alfworld/agents/controller/oracle.py b/alfworld/agents/controller/oracle.py index feeaf3e..3865c69 100644 --- a/alfworld/agents/controller/oracle.py +++ b/alfworld/agents/controller/oracle.py @@ -193,8 +193,10 @@ def step(self, action_str): event = self.env.step({'action': "PickupObject", 'objectId': object['object_id'], 'forceAction': True}) - self.inventory.append(object['num_id']) - self.feedback = "You pick up the %s from the %s." % (obj, tar) + + if event.metadata['lastActionSuccess']: + self.inventory.append(object['num_id']) + self.feedback = "You pick up the %s from the %s." % (obj, tar) elif cmd['action'] == self.Action.PUT: obj, rel, tar = cmd['obj'], cmd['rel'], cmd['tar'] @@ -329,6 +331,8 @@ def step(self, action_str): if event and not event.metadata['lastActionSuccess']: self.feedback = "Nothing happens." + if self.debug: + print(event.metadata['errorMessage']) if self.debug: print(self.feedback) diff --git a/alfworld/agents/environment/alfred_tw_env.py b/alfworld/agents/environment/alfred_tw_env.py index e039fd0..b3ed432 100644 --- a/alfworld/agents/environment/alfred_tw_env.py +++ b/alfworld/agents/environment/alfred_tw_env.py @@ -92,8 +92,8 @@ def _gather_infos(self): def load(self, gamefile): super().load(gamefile) self.gamefile = gamefile - self.request_infos.policy_commands = (self.expert_type == AlfredExpertType.PLANNER) - self.request_infos.facts = (self.expert_type == AlfredExpertType.HANDCODED) + self.request_infos.policy_commands = self.request_infos.policy_commands or (self.expert_type == AlfredExpertType.PLANNER) + self.request_infos.facts = self.request_infos.facts or (self.expert_type == AlfredExpertType.HANDCODED) self._handcoded_expert = HandCodedTWAgent(max_steps=200) def step(self, command): diff --git a/alfworld/gen/constants.py b/alfworld/gen/constants.py index 6e9976e..4233c10 100644 --- a/alfworld/gen/constants.py +++ b/alfworld/gen/constants.py @@ -1,4 +1,5 @@ from collections import OrderedDict +import os ######################################################################################################################## # General Settings @@ -85,7 +86,7 @@ # Unity Hyperparameters BUILD_PATH = None -X_DISPLAY = '0' +X_DISPLAY = None AGENT_STEP_SIZE = 0.25 AGENT_HORIZON_ADJ = 15 diff --git a/alfworld/info.py b/alfworld/info.py index fd9ae8a..f5be671 100644 --- a/alfworld/info.py +++ b/alfworld/info.py @@ -1,4 +1,4 @@ -__version__ = '0.3.3' +__version__ = '0.3.5' import os from os.path import join as pjoin diff --git a/scripts/alfworld-play-thor b/scripts/alfworld-play-thor index bf7971d..fc6f512 100755 --- a/scripts/alfworld-play-thor +++ b/scripts/alfworld-play-thor @@ -14,6 +14,22 @@ from alfworld.env.thor_env import ThorEnv from alfworld.agents.detector.mrcnn import load_pretrained_model from alfworld.agents.controller import OracleAgent, OracleAStarAgent, MaskRCNNAgent, MaskRCNNAStarAgent +prompt_toolkit_available = False +try: + # For command line history and autocompletion. + from prompt_toolkit import prompt + from prompt_toolkit.completion import WordCompleter + from prompt_toolkit.history import InMemoryHistory + prompt_toolkit_available = sys.stdout.isatty() +except ImportError: + pass + +try: + # For command line history when prompt_toolkit is not available. + import readline # noqa: F401 +except ImportError: + pass + def setup_scene(env, traj_data, r_idx, args, reward_type='dense'): # scene setup @@ -73,9 +89,20 @@ def main(args): else: raise NotImplementedError() + history = None + if prompt_toolkit_available: + history = InMemoryHistory() + print(agent.feedback) while True: - cmd = input() + if prompt_toolkit_available: + actions_completer = None + admissible_commands = agent.get_admissible_commands() + actions_completer = WordCompleter(admissible_commands, ignore_case=True, sentence=True) + cmd = prompt('> ', completer=actions_completer, history=history, enable_history_search=True) + else: + cmd = input('> ') + if cmd == "ipdb": from ipdb import set_trace; set_trace() continue @@ -104,9 +131,12 @@ if __name__ == "__main__": if args.problem is None: problems = glob.glob(pjoin(ALFWORLD_DATA, "**", "initial_state.pddl"), recursive=True) - args.problem = os.path.dirname(random.choice(problems)) - - main(args) + # Remove problem which contains movable receptacles. + problems = [p for p in problems if "movable_recep" not in p] + args.problem = os.path.dirname(random.choice(problems)) + if "movable_recep" in args.problem: + raise ValueError("This problem contains movable receptacles, which is not supported by ALFWorld.") + main(args) diff --git a/scripts/alfworld-play-tw b/scripts/alfworld-play-tw index 5d4bd4b..8e52bb0 100755 --- a/scripts/alfworld-play-tw +++ b/scripts/alfworld-play-tw @@ -87,6 +87,12 @@ if __name__ == "__main__": if args.problem is None: problems = glob.glob(pjoin(ALFWORLD_DATA, "**", "initial_state.pddl"), recursive=True) + + # Remove problem which contains movable receptacles. + problems = [p for p in problems if "movable_recep" not in p] args.problem = os.path.dirname(random.choice(problems)) + if "movable_recep" in args.problem: + raise ValueError("This problem contains movable receptacles, which is not supported by ALFWorld.") + main(args)