diff --git a/MANIFEST.in b/MANIFEST.in index 24a63fa..9808edb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include requirements.txt +include requirements-vis.txt include requirements-full.txt include README.md include LICENSE diff --git a/README.md b/README.md index 7e0e954..9867ba2 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,23 @@ For the latest updates, see: [**alfworld.github.io**](https://alfworld.github.io ## Quickstart +Create a virtual environment (recommended) + + conda create -n alfworld python=3.9 + conda activate alfworld + +> [!WARNING] +> If you are using macOS on an ARM-based system, it is recommended to use +> + CONDA_SUBDIR=osx-64 conda create -n alfworld python=3.9 + conda activate alfworld + Install with pip (python3.9+): pip install alfworld[full] +> **Note:** Without the `full` extra, it will only install the text version of ALFWorld. To enable visual modalities, use `pip install alfworld[vis]`. + Download PDDL & Game files and pre-trained MaskRCNN detector: ```bash export ALFWORLD_DATA= @@ -104,12 +117,10 @@ Play around with [TextWorld and THOR demos](scripts/). ## Prerequisites - Python 3.9+ -- PyTorch 1.2.0 (later versions might be ok) -- Torchvision 0.4.0 (later versions might be ok) -- AI2THOR 2.1.0 -See [requirements.txt](requirements.txt) for the prerequisites to run ALFWorld. -See [requirements-full.txt](requirements-full.txt) for the prerequisites to run experiments. +See [requirements.txt](requirements.txt) for the prerequisites to run the text-only version of ALFWorld. +See [requirements-vis.txt](requirements-vis.txt) for the prerequisites to run ALFWorld with both text and visual modalities. +See [requirements-full.txt](requirements-full.txt) for the full prerequisites to run experiments. ## Hardware @@ -122,6 +133,9 @@ Tested on: ## Docker Setup +> [!WARNING] +> This docker setup has been tested for an older version of ALFWorld. 
+ Pull [vzhong](https://github.com/vzhong)'s image: https://hub.docker.com/r/vzhong/alfworld **OR** @@ -212,10 +226,6 @@ You might have to modify `X_DISPLAY` in [gen/constants.py](alfworld/gen/constant Also, checkout this guide: [Setting up THOR on Google Cloud](https://medium.com/@etendue2013/how-to-run-ai2-thor-simulation-fast-with-google-cloud-platform-gcp-c9fcde213a4a) -## Change Log - -18/12/2020: -- PIP package version available. The repo was refactored. ## Citations diff --git a/alfworld/agents/agent/base_agent.py b/alfworld/agents/agent/base_agent.py index f8ecd11..2992b30 100644 --- a/alfworld/agents/agent/base_agent.py +++ b/alfworld/agents/agent/base_agent.py @@ -1,7 +1,12 @@ import logging import numpy as np -import torch -from transformers import DistilBertModel, DistilBertTokenizer + +try: + import torch + from transformers import DistilBertModel, DistilBertTokenizer +except ImportError: + raise ImportError("torch or transformers not found. Please install them via `pip install alfworld[full]`.") + logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) import alfworld.agents.modules.memory as memory diff --git a/alfworld/agents/agent/text_dagger_agent.py b/alfworld/agents/agent/text_dagger_agent.py index c921d5e..9240a45 100644 --- a/alfworld/agents/agent/text_dagger_agent.py +++ b/alfworld/agents/agent/text_dagger_agent.py @@ -4,7 +4,10 @@ from queue import PriorityQueue import numpy as np -import torch +try: + import torch +except ImportError: + raise ImportError("torch not found. 
Please install them via `pip install alfworld[full]`.") logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) from alfworld.agents.agent import BaseAgent @@ -249,7 +252,6 @@ def command_generation_greedy_generation(self, observation_strings, task_desc_st break res = [self.tokenizer.decode(item) for item in input_target_list] res = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in res] - res = [item.replace(" in / on ", " in/on " ) for item in res] return res, current_dynamics def command_generation_beam_search_generation(self, observation_strings, task_desc_strings, previous_dynamics): @@ -348,7 +350,6 @@ def command_generation_beam_search_generation(self, observation_strings, task_de utte_string = self.tokenizer.decode(utte) utterances.append(utte_string) utterances = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in utterances] - utterances = [item.replace(" in / on ", " in/on " ) for item in utterances] res.append(utterances) return res, current_dynamics diff --git a/alfworld/agents/agent/text_dqn_agent.py b/alfworld/agents/agent/text_dqn_agent.py index cfa99de..d6941f5 100644 --- a/alfworld/agents/agent/text_dqn_agent.py +++ b/alfworld/agents/agent/text_dqn_agent.py @@ -4,8 +4,12 @@ from queue import PriorityQueue import numpy as np -import torch -import torch.nn.functional as F + +try: + import torch + import torch.nn.functional as F +except ImportError: + raise ImportError("torch not found. 
Please install them via `pip install alfworld[full]`.") logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) from alfworld.agents.agent import BaseAgent @@ -261,7 +265,6 @@ def command_generation_by_beam_search(self, observation_strings, task_desc_strin utterances.append(utte_string) utterances = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in utterances] - utterances = [item.replace(" in / on ", " in/on " ) for item in utterances] chosen_actions.append(utterances) return chosen_actions, current_dynamics, obs_mask, aggregated_obs_representation @@ -302,7 +305,6 @@ def command_generation_act_greedy(self, observation_strings, task_desc_strings, break chosen_actions = [self.tokenizer.decode(item) for item in input_target_list] chosen_actions = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in chosen_actions] - chosen_actions = [item.replace(" in / on ", " in/on " ) for item in chosen_actions] chosen_indices = [item[1:] for item in input_target_list] for i in range(len(chosen_indices)): if chosen_indices[i][-1] == self.word2id["[SEP]"]: @@ -400,7 +402,6 @@ def command_generation_act_random(self, observation_strings, task_desc_strings, indicies.append(utte) utterances = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in utterances] - utterances = [item.replace(" in / on ", " in/on " ) for item in utterances] indicies = [item[1:] for item in indicies] for i in range(len(indicies)): if indicies[i][-1] == self.word2id["[SEP]"]: diff --git a/alfworld/agents/agent/vision_dagger_agent.py b/alfworld/agents/agent/vision_dagger_agent.py index 6502452..0052bd0 100644 --- a/alfworld/agents/agent/vision_dagger_agent.py +++ b/alfworld/agents/agent/vision_dagger_agent.py @@ -1,10 +1,11 @@ -import os -import sys import copy import numpy as np -import torch -import torch.nn.functional as F +try: + import torch + import torch.nn.functional as F +except ImportError: + raise ImportError("torch not found. 
Please install them via `pip install alfworld[full]`.") import alfworld.agents import alfworld.agents.modules.memory as memory diff --git a/alfworld/agents/controller/mrcnn.py b/alfworld/agents/controller/mrcnn.py index 3ce75c0..34066fa 100644 --- a/alfworld/agents/controller/mrcnn.py +++ b/alfworld/agents/controller/mrcnn.py @@ -20,7 +20,10 @@ from alfworld.agents.controller.base import BaseAgent from alfworld.agents.utils.misc import extract_admissible_commands_with_heuristics -import torchvision.transforms as T +try: + import torchvision.transforms as T +except ImportError: + raise ImportError("torchvision not found. Please install them via `pip install alfworld[full]`.") class MaskRCNNAgent(BaseAgent): diff --git a/alfworld/agents/environment/__init__.py b/alfworld/agents/environment/__init__.py index 09fe688..e69de29 100644 --- a/alfworld/agents/environment/__init__.py +++ b/alfworld/agents/environment/__init__.py @@ -1,3 +0,0 @@ -from alfworld.agents.environment.alfred_tw_env import AlfredTWEnv -from alfworld.agents.environment.alfred_thor_env import AlfredThorEnv -from alfworld.agents.environment.alfred_hybrid import AlfredHybrid \ No newline at end of file diff --git a/alfworld/agents/environment/alfred_hybrid.py b/alfworld/agents/environment/alfred_hybrid.py index 0598368..c745557 100644 --- a/alfworld/agents/environment/alfred_hybrid.py +++ b/alfworld/agents/environment/alfred_hybrid.py @@ -1,7 +1,7 @@ import random -import importlib -import alfworld.agents.environment +from alfworld.agents.environment.alfred_thor_env import AlfredThorEnv +from alfworld.agents.environment.alfred_tw_env import AlfredTWEnv class AlfredHybrid(object): @@ -28,12 +28,12 @@ def choose_env(self): return self.tw def init_env(self, batch_size): - AlfredTWEnv = getattr(alfworld.agents.environment, "AlfredTWEnv")(self.config, train_eval=self.train_eval) - AlfredThorEnv = getattr(alfworld.agents.environment, "AlfredThorEnv")(self.config, train_eval=self.train_eval) + alfred_tw_env = 
AlfredTWEnv(self.config, train_eval=self.train_eval) + alfred_thor_env = AlfredThorEnv(self.config, train_eval=self.train_eval) self.batch_size = batch_size - self.tw = AlfredTWEnv.init_env(batch_size) - self.thor = AlfredThorEnv.init_env(batch_size) + self.tw = alfred_tw_env.init_env(batch_size) + self.thor = alfred_thor_env.init_env(batch_size) return self def seed(self, num): @@ -56,4 +56,4 @@ def reset(self): env = self.choose_env() obs, infos = env.reset() self.num_resets += self.batch_size - return obs, infos \ No newline at end of file + return obs, infos diff --git a/alfworld/agents/environment/alfred_thor_env.py b/alfworld/agents/environment/alfred_thor_env.py index 65ee9df..01177c0 100644 --- a/alfworld/agents/environment/alfred_thor_env.py +++ b/alfworld/agents/environment/alfred_thor_env.py @@ -1,17 +1,15 @@ import os import json -import glob import numpy as np import traceback import threading from queue import Queue from threading import Thread -import sys import random import alfworld.agents -from alfworld.agents.utils.misc import Demangler, get_templated_task_desc, add_task_to_grammar +from alfworld.agents.utils.misc import get_templated_task_desc from alfworld.env.thor_env import ThorEnv from alfworld.agents.expert import HandCodedThorAgent, HandCodedAgentTimeout from alfworld.agents.detector.mrcnn import load_pretrained_model diff --git a/alfworld/agents/expert/handcoded_expert.py b/alfworld/agents/expert/handcoded_expert.py index 6ac6219..9fa8c3d 100644 --- a/alfworld/agents/expert/handcoded_expert.py +++ b/alfworld/agents/expert/handcoded_expert.py @@ -236,7 +236,7 @@ def act(self, game_state, last_action): # if holding something irrelavant, then discard it from where it was pickedup if len(self.inventory) > 0 and not self.is_agent_holding_right_object: if self.curr_recep == self.got_inventory_from_recep: - return "put {} in/on {}".format(self.inventory[0], self.got_inventory_from_recep) + return "move {} to {}".format(self.inventory[0], 
self.got_inventory_from_recep) else: return "go to {}".format(self.got_inventory_from_recep) @@ -276,7 +276,7 @@ def act(self, game_state, last_action): return "open {}".format(self.curr_recep) else: obj = self.inventory[0] - return "put {} in/on {}".format(obj, self.curr_recep) + return "move {} to {}".format(obj, self.curr_recep) # OPEN if sub_action == 'open': @@ -321,7 +321,7 @@ def act(self, game_state, last_action): # if holding something irrelavant, then discard it from where it was pickedup if len(self.inventory) > 0 and not self.is_agent_holding_right_object: if self.curr_recep == self.got_inventory_from_recep: - return "put {} in/on {}".format(self.inventory[0], self.got_inventory_from_recep) + return "move {} to {}".format(self.inventory[0], self.got_inventory_from_recep) else: return "go to {}".format(self.got_inventory_from_recep) diff --git a/alfworld/agents/expert/handcoded_expert_tw.py b/alfworld/agents/expert/handcoded_expert_tw.py index d58ac21..da8bdc1 100644 --- a/alfworld/agents/expert/handcoded_expert_tw.py +++ b/alfworld/agents/expert/handcoded_expert_tw.py @@ -14,7 +14,7 @@ def get_predicates(self, game_state, obj, parent): obs_at_curr_recep = self.obs_at_recep[self.curr_recep] if self.curr_recep in self.obs_at_recep else "" is_obj_in_obs = "you see" in obs_at_curr_recep and " {} ".format(obj) in obs_at_curr_recep at_right_recep = parent in self.curr_recep - can_put_object = "put {} in/on {}".format(obj, parent) in admissible_commands_wo_num_ids + can_put_object = "move {} to {}".format(obj, parent) in admissible_commands_wo_num_ids can_take_object = any("take {}".format(obj) in ac for ac in admissible_commands_wo_num_ids) return at_right_recep, can_put_object, can_take_object, is_obj_in_obs @@ -35,7 +35,7 @@ def get_predicates(self, game_state, obj, parent): obs_at_curr_recep = self.obs_at_recep[self.curr_recep] if self.curr_recep in self.obs_at_recep else "" is_obj_in_obs = "you see" in obs_at_curr_recep and " {} ".format(obj) in 
obs_at_curr_recep at_right_recep = parent in self.curr_recep - can_put_object = "put {} in/on {}".format(obj, parent) in admissible_commands_wo_num_ids + can_put_object = "move {} to {}".format(obj, parent) in admissible_commands_wo_num_ids can_take_object = any("take {}".format(obj) in ac for ac in admissible_commands_wo_num_ids) return at_right_recep, can_put_object, can_take_object, is_obj_in_obs, is_one_object_already_inside_receptacle, trying_to_take_the_same_object @@ -68,7 +68,7 @@ def get_predicates(self, game_state, obj, parent): obs_at_curr_recep = self.obs_at_recep[self.curr_recep] if self.curr_recep in self.obs_at_recep else "" is_obj_in_obs = "you see" in obs_at_curr_recep and " {} ".format(obj) in obs_at_curr_recep at_right_recep = parent in self.curr_recep - can_put_object = "put {} in/on {}".format(obj, parent) in admissible_commands_wo_num_ids + can_put_object = "move {} to {}".format(obj, parent) in admissible_commands_wo_num_ids can_take_object = any("take {}".format(obj) in ac for ac in admissible_commands_wo_num_ids) can_heat_object = "heat {} with {}".format(obj, "microwave") in admissible_commands_wo_num_ids return at_right_recep, can_heat_object, can_put_object, can_take_object, is_obj_in_obs, is_the_object_agent_holding_hot @@ -87,7 +87,7 @@ def get_predicates(self, game_state, obj, parent): obs_at_curr_recep = self.obs_at_recep[self.curr_recep] if self.curr_recep in self.obs_at_recep else "" is_obj_in_obs = "you see" in obs_at_curr_recep and " {} ".format(obj) in obs_at_curr_recep at_right_recep = parent in self.curr_recep - can_put_object = "put {} in/on {}".format(obj, parent) in admissible_commands_wo_num_ids + can_put_object = "move {} to {}".format(obj, parent) in admissible_commands_wo_num_ids can_cool_object = "cool {} with {}".format(obj, "fridge") in admissible_commands_wo_num_ids can_take_object = any("take {}".format(obj) in ac for ac in admissible_commands_wo_num_ids) return at_right_recep, can_cool_object, can_put_object, 
can_take_object, is_obj_in_obs, is_the_object_agent_holding_cool @@ -106,7 +106,7 @@ def get_predicates(self, game_state, obj, parent): obs_at_curr_recep = self.obs_at_recep[self.curr_recep] if self.curr_recep in self.obs_at_recep else "" is_obj_in_obs = "you see" in obs_at_curr_recep and " {} ".format(obj) in obs_at_curr_recep at_right_recep = parent in self.curr_recep - can_put_object = "put {} in/on {}".format(obj, parent) in admissible_commands_wo_num_ids + can_put_object = "move {} to {}".format(obj, parent) in admissible_commands_wo_num_ids can_clean_object = "clean {} with {}".format(obj, "sinkbasin") in admissible_commands_wo_num_ids can_take_object = any("take {}".format(obj) in ac for ac in admissible_commands_wo_num_ids) return at_right_recep, can_clean_object, can_put_object, can_take_object, is_obj_in_obs, is_the_object_agent_holding_isclean @@ -125,4 +125,4 @@ def get_task_policy(self, task_param): if task_class_str in globals(): return globals()[task_class_str] else: - raise Exception("Invalid Task Type: %s" % task_type) \ No newline at end of file + raise Exception("Invalid Task Type: %s" % task_type) diff --git a/alfworld/agents/utils/misc.py b/alfworld/agents/utils/misc.py index 632c21f..3d0297e 100644 --- a/alfworld/agents/utils/misc.py +++ b/alfworld/agents/utils/misc.py @@ -154,7 +154,7 @@ def extract_admissible_commands_with_heuristics(intro, frame_desc, feedback, "open {recep}", "close {recep}", "take {obj} from {recep}", - "put {obj} in/on {recep}", + "move {obj} to {recep}", "use {lamp}", "heat {obj} with {microwave}", "cool {obj} with {fridge}", @@ -162,6 +162,7 @@ def extract_admissible_commands_with_heuristics(intro, frame_desc, feedback, "slice {obj} with {knife}", "inventory", "look", + "help", "examine {obj}", "examine {recep}" ] @@ -182,7 +183,7 @@ def extract_admissible_commands_with_heuristics(intro, frame_desc, feedback, for obj in objects: if 'desklamp' not in obj and 'floorlamp' not in obj: 
admissible_commands.append(t.format(recep=at_recep, obj=obj)) - elif 'put {obj} in/on {recep}' in t: + elif 'move {obj} to {recep}' in t: if in_inv and at_recep: admissible_commands.append(t.format(recep=at_recep, obj=in_inv)) elif '{obj}' in t and '{microwave}' in t: @@ -246,7 +247,7 @@ def extract_admissible_commands(intro, frame_desc): "open {recep}", "close {recep}", "take {obj} from {recep}", - "put {obj} in/on {recep}", + "move {obj} to {recep}", "use {lamp}", "heat {obj} with {microwave}", "cool {obj} with {fridge}", @@ -254,6 +255,7 @@ def extract_admissible_commands(intro, frame_desc): "slice {obj} with {knife}", "inventory", "look", + "help", "examine {obj}", "examine {recep}" ] diff --git a/alfworld/data/alfred.pddl b/alfworld/data/alfred.pddl index b087c0e..9c0d812 100644 --- a/alfworld/data/alfred.pddl +++ b/alfworld/data/alfred.pddl @@ -495,4 +495,14 @@ ) +(:action help + :parameters (?a - agent) + :precondition + () + :effect + (and + (checked ?a) + ) +) + ) diff --git a/alfworld/data/alfred.twl2 b/alfworld/data/alfred.twl2 index d391142..47a7c9f 100644 --- a/alfworld/data/alfred.twl2 +++ b/alfworld/data/alfred.twl2 @@ -20,7 +20,7 @@ grammar :: """ "GotoLocation.feedback": [ { - "rhs": "You arrive at {lend.name}. #examineReceptacle.feedback#" + "rhs": "You arrive at {r.name}. #examineReceptacle.feedback#" } ], @@ -62,7 +62,7 @@ grammar :: """ "PutObject.feedback": [ { - "rhs": "You put the {o.name} in/on the {r.name}." + "rhs": "You move the {o.name} to the {r.name}." 
} ], @@ -208,7 +208,7 @@ action PickupFullReceptacleObject { } action PutObject { - template :: "put {o} in/on {r}"; + template :: "move {o} to {r}"; feedback :: "#PutObject.feedback#"; } @@ -218,7 +218,7 @@ action PutObjectInReceptacleObject { } action PutEmptyReceptacleObjectinReceptacle { - template :: "put {o} in/on {r}"; + template :: "move {o} to {r}"; feedback :: "#PutEmptyReceptacleObjectinReceptacle.feedback#"; } @@ -420,3 +420,18 @@ action look { } """; } + +action help { + template :: "help"; + feedback :: "#help.feedback#"; + + grammar :: """ + { + "help.feedback": [ + { + "rhs": "\nAvailable commands:\n look: look around your current location\n inventory: check your current inventory\n go to (receptacle): move to a receptacle\n open (receptacle): open a receptacle\n close (receptacle): close a receptacle\n take (object) from (receptacle): take an object from a receptacle\n move (object) to (receptacle): place an object in or on a receptacle\n examine (something): examine a receptacle or an object\n use (object): use an object\n heat (object) with (receptacle): heat an object using a receptacle\n clean (object) with (receptacle): clean an object using a receptacle\n cool (object) with (receptacle): cool an object using a receptacle\n slice (object) with (object): slice an object using a sharp object\n" + } + ] + } + """; +} diff --git a/alfworld/env/thor_env.py b/alfworld/env/thor_env.py index 3ee602f..4e3ad4e 100644 --- a/alfworld/env/thor_env.py +++ b/alfworld/env/thor_env.py @@ -1,10 +1,14 @@ -import cv2 import copy import os import glob import numpy as np from collections import Counter, OrderedDict -from ai2thor.controller import Controller + +try: + from ai2thor.controller import Controller + import cv2 +except ImportError: + raise ImportError("ai2thor or cv2 not found. 
Please install them via `pip install alfworld[vis]`.") import alfworld.gen.constants as constants from alfworld.env.tasks import get_task diff --git a/alfworld/info.py b/alfworld/info.py index f5be671..83a5066 100644 --- a/alfworld/info.py +++ b/alfworld/info.py @@ -1,4 +1,4 @@ -__version__ = '0.3.5' +__version__ = '0.4.0' import os from os.path import join as pjoin diff --git a/requirements-vis.txt b/requirements-vis.txt new file mode 100644 index 0000000..fb91442 --- /dev/null +++ b/requirements-vis.txt @@ -0,0 +1,7 @@ +ai2thor==2.1.0 +opencv-python +pandas +torch +torchvision +tqdm +werkzeug==2.0.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8b4a6fb..843e46c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1 @@ -ai2thor==2.1.0 -opencv-python -pandas -textworld[pddl]>=1.6.1 -torch -torchvision -tqdm -werkzeug==2.0.3 \ No newline at end of file +textworld[pddl]>=1.6.1 \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md index aa072c8..c389f90 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -45,7 +45,7 @@ You heat the potato 1 using the microwave 1. You arrive at loc 8. On the sinkbasin 1, you see a knife 3, a egg 2, and a dishsponge 3. -> put potato 1 in/on sinkbasin 1 +> move potato 1 to sinkbasin 1 You won! 
``` diff --git a/scripts/alfworld-download b/scripts/alfworld-download index ae9ec3e..f1721ea 100755 --- a/scripts/alfworld-download +++ b/scripts/alfworld-download @@ -17,7 +17,7 @@ from alfworld.info import ALFRED_PDDL_PATH, ALFRED_TWL2_PATH JSON_FILES_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/json_2.1.1_json.zip" PDDL_FILES_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/json_2.1.1_pddl.zip" -TW_PDDL_FILES_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/json_2.1.1_tw-pddl.zip" +TW_PDDL_FILES_URL = "https://github.com/alfworld/alfworld/releases/download/0.4.0/json_2.1.2_tw-pddl.zip" MRCNN_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/mrcnn_alfred_objects_sep13_004.pth" CHECKPOINTS_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/pretrained_checkpoints.zip" SEQ2SEQ_DATA_URL = "https://github.com/alfworld/alfworld/releases/download/0.2.2/seq2seq_data.zip" diff --git a/scripts/alfworld-play-thor b/scripts/alfworld-play-thor index fc6f512..22caecf 100755 --- a/scripts/alfworld-play-thor +++ b/scripts/alfworld-play-thor @@ -134,6 +134,10 @@ if __name__ == "__main__": # Remove problem which contains movable receptacles. problems = [p for p in problems if "movable_recep" not in p] + + if len(problems) == 0: + raise ValueError(f"Can't find problem files in {ALFWORLD_DATA}. 
Did you run alfworld-download?") + args.problem = os.path.dirname(random.choice(problems)) if "movable_recep" in args.problem: diff --git a/scripts/alfworld-play-tw b/scripts/alfworld-play-tw index 8e52bb0..9a014f6 100755 --- a/scripts/alfworld-play-tw +++ b/scripts/alfworld-play-tw @@ -34,7 +34,7 @@ def main(args): # dump game file gamedata = dict(**GAME_LOGIC, pddl_problem=open(pddl_file).read()) gamefile = os.path.join(os.path.dirname(pddl_file), 'game.tw-pddl') - json.dump(gamedata, open(gamefile, "w")) + #json.dump(gamedata, open(gamefile, "w")) expert = AlfredExpert(expert_type=AlfredExpertType.PLANNER) @@ -90,6 +90,10 @@ if __name__ == "__main__": # Remove problem which contains movable receptacles. problems = [p for p in problems if "movable_recep" not in p] + + if len(problems) == 0: + raise ValueError(f"Can't find problem files in {ALFWORLD_DATA}. Did you run alfworld-download?") + args.problem = os.path.dirname(random.choice(problems)) if "movable_recep" in args.problem: diff --git a/scripts/patch_tw-pddl.py b/scripts/patch_tw-pddl.py new file mode 100644 index 0000000..ac0fe5a --- /dev/null +++ b/scripts/patch_tw-pddl.py @@ -0,0 +1,67 @@ +# This script patches the 2.1.1 game.tw-pddl files to bring them up to 2.1.2. +# It adds a new "help" action to the PDDL domain and renames the +# "put {o} in/on {r}" command to "move {o} to {r}" (feedback text included). +# It also appends the corresponding grammar for the new "help" action. +# It also patches the grammar to fix a typo in the go-to feedback. +# The script will create a backup of the original file before patching. 
+ +import os +import json +from glob import glob +from os.path import join as pjoin +from tqdm import tqdm + +from alfworld.info import ALFWORLD_DATA + + +HELP_ACTION_PDDL = """\ +(:action help + :parameters (?a - agent) + :precondition + () + :effect + (and + (checked ?a) + ) +) +""" + +HELP_GRAMMAR = """\ + +action help { + template :: "help"; + feedback :: "\nAvailable commands:\n look: look around your current location\n inventory: check your current inventory\n go to (receptacle): move to a receptacle\n open (receptacle): open a receptacle\n close (receptacle): close a receptacle\n take (object) from (receptacle): take an object from a receptacle\n move (object) to (receptacle): place an object in or on a receptacle\n examine (something): examine a receptacle or an object\n use (object): use an object\n heat (object) with (receptacle): heat an object using a receptacle\n clean (object) with (receptacle): clean an object using a receptacle\n cool (object) with (receptacle): cool an object using a receptacle\n slice (object) with (object): slice an object using a sharp object\n"; +} +""" + +def patch_twpddl(filename): + with open(filename, "r") as f: + data = json.load(f) + + # Make backup if doesn't exist + if not os.path.exists(filename + ".bak"): + with open(filename + ".bak", "w") as f: + json.dump(data, f) + + # Always start from backup. + with open(filename + ".bak") as f: + data = json.load(f) + + # Patch domain pddl. + before, after = data["pddl_domain"].rsplit(")", 1) + data["pddl_domain"] = f"{before}{HELP_ACTION_PDDL}){after}" + + # Patch grammar. 
+ data["grammar"] = data["grammar"].replace("You arrive at {lend.name}.", "You arrive at {r.name}.") + data["grammar"] = data["grammar"].replace("put {o} in/on {r}", "move {o} to {r}") + data["grammar"] = data["grammar"].replace("You put the {o.name} in/on the {r.name}.", "You move the {o.name} to the {r.name}.") + data["grammar"] += HELP_GRAMMAR + + with open(filename, "w") as f: + json.dump(data, f) + + +filenames = glob(pjoin(ALFWORLD_DATA, "json_2.1.1/**/**/**/*.tw-pddl")) + +for filename in tqdm(filenames): + patch_twpddl(filename) diff --git a/setup.py b/setup.py index 23223bc..7ece51e 100644 --- a/setup.py +++ b/setup.py @@ -26,5 +26,6 @@ install_requires=open('requirements.txt').readlines(), extras_require={ 'full': open('requirements-full.txt').readlines(), + 'vis': open('requirements-vis.txt').readlines(), } )