From b89ed6765ccd7ae8af510953f02a7ab24376a0fc Mon Sep 17 00:00:00 2001 From: HakiRose Date: Fri, 7 Feb 2020 20:00:50 -0500 Subject: [PATCH 1/5] New project to implement the EventAction class --- textworld/generator/game.py | 179 +++++++++++++++---- textworld/generator/inform7/world2inform7.py | 69 +++++-- textworld/generator/maker.py | 27 +++ textworld/logic/__init__.py | 48 ++++- 4 files changed, 277 insertions(+), 46 deletions(-) diff --git a/textworld/generator/game.py b/textworld/generator/game.py index e7feb9b8..53f3998d 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -39,6 +39,12 @@ def __init__(self): super().__init__(msg) +class UnderspecifiedEventActionError(NameError): + def __init__(self): + msg = "No action is defined, action is required to create an event." + super().__init__(msg) + + class UnderspecifiedQuestError(NameError): def __init__(self): msg = "At least one winning or failing event is needed to create a quest." @@ -68,13 +74,13 @@ class Event: """ Event happening in TextWorld. - An event gets triggered when its set of conditions become all statisfied. + An event gets triggered when its set of conditions become all satisfied. Attributes: actions: Actions to be performed to trigger this event commands: Human readable version of the actions. condition: :py:class:`textworld.logic.Action` that can only be applied - when all conditions are statisfied. + when all conditions are satisfied. """ def __init__(self, actions: Iterable[Action] = (), @@ -111,6 +117,7 @@ def commands(self, commands: Iterable[str]) -> None: def is_triggering(self, state: State) -> bool: """ Check if this event would be triggered in a given state. """ + return state.is_applicable(self.condition) def set_conditions(self, conditions: Iterable[Proposition]) -> Action: @@ -167,6 +174,7 @@ def serialize(self) -> Mapping: `Event`'s data serialized to be JSON compatible. """ data = {} + # data["class"] = "Event" data["commands"] = self.commands data["actions"] = [action.serialize() for action in self.actions] data["condition"] = self.condition.serialize() @@ -177,6 +185,85 @@ def copy(self) -> "Event": return self.deserialize(self.serialize()) +class EventAction: + def __init__(self, action: Iterable[Action] = (), precond_verb_tense: dict = (), postcond_verb_tense: dict = ()) -> None: + self.verb_tense_precond = precond_verb_tense + self.verb_tense_postcond = postcond_verb_tense + self.event, self.actions = self.set_actions(action) + + def set_parameters(self, output: dict, acts: Iterable[Proposition], verbs: Iterable[dict]): + def tense(val): + if val == 1: + return 'will' + elif val == 0: + return 'is' + elif val == -1: + return 'was' + elif val == -2: + return 'has been' + elif val == -3: + return 'had been' + + if not verbs: + return output + + for prop in acts: + if prop.name in verbs.keys(): + output['name'].append(prop.name) + [output['argument'].append(v) for v in prop.arguments] + output['verb_val'].append(verbs[prop.name]) + output['verb_def'].append(tense(verbs[prop.name])) + + return output + + def set_actions(self, action: Iterable[Action]): + # tp_action = action + tp_action = [a for a in action] + params = {'name': [], 'argument': [], 'verb_val': [], 'verb_def': []} + params = self.set_parameters(params, tp_action[0].removed, self.verb_tense_precond) + params = self.set_parameters(params, tp_action[0].added, self.verb_tense_postcond) + event = Proposition("event", arguments=params['argument'], definition=params['name'], + verb_var=params['verb_val'], verb_def=params['verb_def']) + return event, action + + def is_triggering(self, action: Action) -> bool: + """ Check if this event would be triggered for a given action. """ + return action == [a for a in self.actions][0] + + @classmethod + def deserialize(cls, data: Mapping) -> "EventAction": + """ Creates an `EventAction` from serialized data. + + Args: + data: Serialized data with the needed information to build a + `EventAction` object. + """ + actions = [Action.deserialize(d) for d in data["actions"]] + event = cls(actions, data["precond_verb_tense"], data["postcond_verb_tense"]) + return event + + def serialize(self) -> Mapping: + """ Serialize this event. + + Results: + `EventAction`'s data serialized to be JSON compatible. + """ + return {"actions": [action.serialize() for action in self.actions], + "precond_verb_tense": self.verb_tense_precond, + "postcond_verb_tense": self.verb_tense_postcond, + } + + def __hash__(self) -> int: + return hash((self.actions, self.event, self.verb_tense_precond, self.verb_tense_postcond)) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventAction) and + self.actions == other.actions and + self.event == other.event and + self.verb_tense_precond == other.verb_tense_precond and + self.verb_tense_postcond == other.verb_tense_postcond) + + class Quest: """ Quest representation in TextWorld. @@ -196,8 +283,10 @@ class Quest: """ def __init__(self, - win_events: Iterable[Event] = (), - fail_events: Iterable[Event] = (), + # win_events: Iterable[Event] = (), + # fail_events: Iterable[Event] = (), + win_events: Iterable = (), + fail_events: Iterable = (), reward: Optional[int] = None, desc: Optional[str] = None, commands: Iterable[str] = ()) -> None: @@ -279,8 +368,22 @@ def deserialize(cls, data: Mapping) -> "Quest": data: Serialized data with the needed information to build a `Quest` object. """ - win_events = [Event.deserialize(d) for d in data["win_events"]] - fail_events = [Event.deserialize(d) for d in data["fail_events"]] + win_events = [] + for d in data["win_events"]: + if "precond_verb_tense" in d.keys(): + win_events.append(EventAction.deserialize(d)) + + if "condition" in d.keys(): + win_events.append(Event.deserialize(d)) + + fail_events = [] + for d in data["fail_events"]: + if "precond_verb_tense" in d.keys(): + fail_events.append(EventAction.deserialize(d)) + + if "condition" in d.keys(): + fail_events.append(Event.deserialize(d)) + commands = data.get("commands", []) reward = data["reward"] desc = data["desc"] @@ -426,6 +529,8 @@ def change_grammar(self, grammar: Grammar) -> None: generate_text_from_grammar(self, self.grammar) for quest in self.quests: + # TODO: should have a generic way of generating text commands from actions + # instead of relying on inform7 convention. for event in quest.win_events: event.commands = _gen_commands(event.actions) @@ -578,21 +683,6 @@ def objective(self) -> str: def objective(self, value: str): self._objective = value - @property - def walkthrough(self) -> Optional[List[str]]: - walkthrough = self.metadata.get("walkthrough") - if walkthrough: - return walkthrough - - # Check if we can derive a walkthrough from the quests. - policy = GameProgression(self).winning_policy - if policy: - mapping = {k: info.name for k, info in self._infos.items()} - walkthrough = [a.format_command(mapping) for a in policy] - self.metadata["walkthrough"] = walkthrough - - return walkthrough - class ActionDependencyTreeElement(DependencyTreeElement): """ Representation of an `Action` in the dependency tree. @@ -613,7 +703,11 @@ def depends_on(self, other: "ActionDependencyTreeElement") -> bool: of the action1 is not empty, i.e. action1 needs the propositions added by action2. """ - return len(other.action.added & self.action._pre_set) > 0 + if isinstance(self.action, frozenset): + act = d = [a for a in self.action][0] + else: + act = self.action + return len(other.action.added & act._pre_set) > 0 @property def action(self) -> Action: @@ -722,7 +816,7 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event: Event, kb: KnowledgeBase) -> None: + def __init__(self, event, kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. @@ -737,13 +831,18 @@ def __init__(self, event: Event, kb: KnowledgeBase) -> None: self._tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) - if len(event.actions) > 0: - self._tree.push(event.condition) + if isinstance(event, Event): + if len(event.actions) > 0: + self._tree.push(event.condition) + + for action in event.actions[::-1]: + self._tree.push(action) - for action in event.actions[::-1]: - self._tree.push(action) + self._policy = event.actions + (event.condition,) - self._policy = event.actions + (event.condition,) + # if isinstance(event, EventAction): + # self._tree.push([a for a in event.actions][0]) + # self._policy = event.actions def copy(self) -> "EventProgression": """ Return a soft copy. """ @@ -790,7 +889,12 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if state is not None: # Check if event is triggered. - self._triggered = self.event.is_triggering(state) + + if isinstance(self.event, Event): + self._triggered = self.event.is_triggering(state) + + if isinstance(self.event, EventAction): + self._triggered = self.event.is_triggering(action) # Try compressing the winning policy given the new game state. if self.compress_policy(state): @@ -941,7 +1045,6 @@ def __init__(self, game: Game, track_quests: bool = True) -> None: self.state = game.world.state.copy() self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) - self.quest_progressions = [] if track_quests: self.quest_progressions = [QuestProgression(quest, game.kb) for quest in game.quests] @@ -1033,11 +1136,23 @@ def update(self, action: Action) -> None: # Get valid actions. self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) - # Update all quest progressions given the last action and new state. for quest_progression in self.quest_progressions: quest_progression.update(action, self.state) + for quest_progression in self.quest_progressions: + for win_event in quest_progression.win_events: + if quest_progression.quest.reward >= 0: + if isinstance(win_event.event, Event): + self.state.apply(win_event.event.condition) + if isinstance(win_event.event, EventAction): + propos = [prop for prop in win_event.event.actions[0].added] + self.state.apply(Action("trigger", preconditions=propos, + postconditions=propos + [win_event.event.event])) + + self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) + class GameOptions: """ @@ -1162,7 +1277,7 @@ def _key_missing(seeds): @property def rngs(self) -> Dict[str, RandomState]: rngs = {} - for key, seed in self.seeds.items(): + for key, seed in self._seeds.items(): rngs[key] = RandomState(seed) return rngs diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index 45669e22..4877f43b 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -15,7 +15,7 @@ from textworld.utils import make_temp_directory, str2bool, chunk -from textworld.generator.game import Game +from textworld.generator.game import Event, EventAction, Game from textworld.generator.world import WorldRoom, WorldEntity from textworld.logic import Signature, Proposition, Action, Variable @@ -121,6 +121,26 @@ def gen_source_for_conditions(self, conds: Iterable[Proposition]) -> str: return " and ".join(i7_conds) + def gen_source_for_rule(self, rule: Action) -> Optional[str]: + pt = self.kb.inform7_events[rule.name] + if pt is None: + msg = "Undefined Inform7's command: {}".format(rule.name) + warnings.warn(msg, TextworldInform7Warning) + return None + + return pt.format(**self._get_entities_mapping(rule)) + + def gen_source_for_actions(self, acts: Iterable[Action]) -> str: + """Generate Inform 7 source for winning/losing actions.""" + + i7_acts = [] + for act in acts: + i7_act = self.gen_source_for_rule(act) + if i7_act: + i7_acts.append(i7_act) + + return " and ".join(i7_acts) + def gen_source_for_objects(self, objects: Iterable[WorldEntity]) -> str: source = "" for obj in objects: @@ -195,6 +215,10 @@ def _get_name_mapping(self, action): mapping = self.kb.rules[action.name].match(action) return {ph.name: self.entity_infos[var.name].name for ph, var in mapping.items()} + def _get_entities_mapping(self, action): + mapping = self.kb.rules[action.name].match(action) + return {ph.name: self.entity_infos[var.name].id for ph, var in mapping.items()} + def gen_commands_from_actions(self, actions: Iterable[Action]) -> List[str]: commands = [] for action in actions: @@ -239,6 +263,8 @@ def detect_action(self, i7_event: str, actions: Iterable[Action]) -> Optional[Ac """ # Prioritze actions with many precondition terms. actions = sorted(actions, key=lambda a: len(a.preconditions), reverse=True) + from pprint import pprint + pprint(actions) for action in actions: event = self.kb.inform7_events[action.name] if event.format(**self._get_name_mapping(action)).lower() == i7_event.lower(): @@ -312,6 +338,7 @@ def gen_source(self, seed: int = 1234) -> str: objective = self.game.objective.replace("\n", "[line break]") maximum_score = 0 + wining = 0 for quest_id, quest in enumerate(self.game.quests): maximum_score += quest.reward @@ -337,20 +364,39 @@ def gen_source(self, seed: int = 1234) -> str: else if {conditions}: end the story; [Lost]""") - win_template = textwrap.dedent(""" + win_template_state = textwrap.dedent(""" else if {conditions}: increase the score by {reward}; [Quest completed] Now the quest{quest_id} completed is true;""") + win_template_action = textwrap.dedent(""" + else: + After {conditions}: + increase the score by {reward}; [Quest completed] + Now the quest{quest_id} completed is true;""") + for fail_event in quest.fail_events: - conditions = self.gen_source_for_conditions(fail_event.condition.preconditions) + if isinstance(fail_event, Event): + param = fail_event.condition + if isinstance(fail_event, EventAction): + param = [act for act in fail_event.actions][0] + + conditions = self.gen_source_for_conditions(param.preconditions) quest_ending_conditions += fail_template.format(conditions=conditions) for win_event in quest.win_events: - conditions = self.gen_source_for_conditions(win_event.condition.preconditions) - quest_ending_conditions += win_template.format(conditions=conditions, - reward=quest.reward, - quest_id=quest_id) + if isinstance(win_event, Event): + conditions = self.gen_source_for_conditions(win_event.condition.preconditions) + quest_ending_conditions += win_template_state.format(conditions=conditions, + reward=quest.reward, + quest_id=quest_id) + + if isinstance(win_event, EventAction): + conditions = self.gen_source_for_actions([act for act in win_event.actions]) + quest_ending_conditions += win_template_action.format(conditions=conditions, + reward=quest.reward, + quest_id=quest_id) + wining += 1 quest_ending = """\ Every turn:\n{conditions} @@ -359,13 +405,15 @@ def gen_source(self, seed: int = 1234) -> str: source += textwrap.dedent(quest_ending) # Enable scoring is at least one quest has nonzero reward. - if maximum_score != 0: + if maximum_score >= 0: source += "Use scoring. The maximum score is {}.\n".format(maximum_score) + # Build test condition for winning the game. game_winning_test = "1 is 0 [always false]" - if len(self.game.quests) > 0: - game_winning_test = "score is maximum score" + if wining > 0: + if maximum_score != 0: + game_winning_test = "score is at least maximum score" # Remove square bracket when printing score increases. Square brackets are conflicting with # Inform7's events parser in tw_inform7.py. @@ -383,6 +431,7 @@ def gen_source(self, seed: int = 1234) -> str: if {game_winning_test}: end the story finally; [Win] + The simpler notify score changes rule substitutes for the notify score changes rule. """.format(game_winning_test=game_winning_test)) diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 1b89f60c..9dcfd0d7 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -682,6 +682,33 @@ def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: args = [entity.var for entity in entities] return Proposition(name, args) + def new_rule_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + """ Create new fact about a rule. + + Args: + name: The name of the rule which can be used for the new rule fact as well. + *entities: A list of entities as arguments to the new rule fact. + """ + + def new_conditions(conditions, args): + new_ph = [] + for pred in conditions: + new_var = [var for ph in pred.parameters for var in args if ph.type == var.type] + new_ph.append(Proposition(pred.name, new_var)) + + return new_ph + + args = [entity.var for entity in entities] + + for rule in self._kb.rules.values(): + if rule.name == name.name: + precond = new_conditions(rule.preconditions, args) + postcond = new_conditions(rule.postconditions, args) + + return Action(rule.name, precond, postcond) + + return None + def new_event_using_commands(self, commands: List[str]) -> Event: """ Creates a new event using predefined text commands. diff --git a/textworld/logic/__init__.py b/textworld/logic/__init__.py index 007a684f..c5bf4966 100644 --- a/textworld/logic/__init__.py +++ b/textworld/logic/__init__.py @@ -615,9 +615,10 @@ class Proposition(with_metaclass(PropositionTracker, object)): An instantiated Predicate, with concrete variables for each placeholder. """ - __slots__ = ("name", "arguments", "signature", "_hash") + __slots__ = ("name", "arguments", "signature", "_hash", "definition", "verb_var", "verb_def") - def __init__(self, name: str, arguments: Iterable[Variable] = []): + def __init__(self, name: str, arguments: Iterable[Variable] = [], definition: str = None, + verb_var: int = None, verb_def: str = None): """ Create a Proposition. @@ -633,6 +634,9 @@ def __init__(self, name: str, arguments: Iterable[Variable] = []): self.arguments = tuple(arguments) self.signature = Signature(name, [var.type for var in self.arguments]) self._hash = hash((self.name, self.arguments)) + self.definition = definition + self.verb_var = verb_var + self.verb_def = verb_def @property def names(self) -> Collection[str]: @@ -648,11 +652,47 @@ def types(self) -> Collection[str]: """ return self.signature.types + def make_str(self, max_arg=False): + args = [v for v in self.arguments] + txt = [] + for i in range(len(args)): + if max_arg: + txt.append("({})".format(", ".join(map(str, [args[i], self.verb_def[i], self.definition[i]])))) + else: + txt.append("({})".format(", ".join(map(str, [args[i], self.definition[i]])))) + + return "{}".format(", ".join(txt)) + def __str__(self): - return "{}({})".format(self.name, ", ".join(map(str, self.arguments))) + def make_str(max_arg=False): + args = [v for v in self.arguments] + txt = [] + for i in range(len(args)): + if max_arg: + txt.append("({})".format(", ".join(map(str, [args[i], self.verb_def[i], self.definition[i]])))) + else: + txt.append("({})".format(", ".join(map(str, [args[i], self.definition[i]])))) + + return "{}".format(", ".join(txt)) + + if self.definition and self.verb_def: + return "{}({txt})".format(self.name, txt=make_str(max_arg=True)) + + elif self.definition: + return "{}({txt})".format(self.name, txt=make_str()) + + else: + return "{}({})".format(self.name, ", ".join(map(str, self.arguments))) def __repr__(self): - return "Proposition({!r}, {!r})".format(self.name, self.arguments) + if self.definition and self.verb: + return "Proposition({!r}, {!r}, {!r}, {!r})".format(self.name, self.arguments, self.definition, + self.verb_def) + elif self.definition: + return "Proposition({!r}, {!r}, {!r})".format(self.name, self.arguments, self.definition) + + else: + return "Proposition({!r}, {!r})".format(self.name, self.arguments) def __eq__(self, other): if isinstance(other, Proposition): From 0570faa31ac81ad1798c53315956083fe2b42770 Mon Sep 17 00:00:00 2001 From: HakiRose Date: Fri, 28 Feb 2020 13:16:41 -0500 Subject: [PATCH 2/5] The new updates over the NEW STYLE of the FRAMEWORK after the rebase, the fully operating structure of the FW --- .gitignore | 3 +- .../tw_coin_collector/coin_collector.py | 21 +- .../textworld_data/logic/player.twl | 4 - .../tw_simple/textworld_data/logic/key.twl | 44 +-- .../tw_simple/textworld_data/logic/room.twl | 17 +- textworld/generator/game.py | 302 +++++++++++------- textworld/logic/__init__.py | 231 ++++++++++---- 7 files changed, 398 insertions(+), 224 deletions(-) diff --git a/.gitignore b/.gitignore index c2aaddf5..6632a7cd 100644 --- a/.gitignore +++ b/.gitignore @@ -24,5 +24,4 @@ tmp/* *.ipynb_checkpoints /dist /wheelhouse -docs/build -docs/src +*.orig diff --git a/textworld/challenges/tw_coin_collector/coin_collector.py b/textworld/challenges/tw_coin_collector/coin_collector.py index d916d27b..243f66ce 100644 --- a/textworld/challenges/tw_coin_collector/coin_collector.py +++ b/textworld/challenges/tw_coin_collector/coin_collector.py @@ -15,23 +15,17 @@ other than the coin to collect. """ -import os import argparse -from os.path import join as pjoin from typing import Mapping, Optional, Any import textworld from textworld.generator.graph_networks import reverse_direction from textworld.utils import encode_seeds -from textworld.generator.data import KnowledgeBase -from textworld.generator.game import GameOptions, Quest, Event +from textworld.generator.game import GameOptions, Quest, EventCondition from textworld.challenges import register -KB_PATH = pjoin(os.path.dirname(__file__), "textworld_data") - - def build_argparser(parser=None): parser = parser or argparse.ArgumentParser() @@ -39,6 +33,9 @@ def build_argparser(parser=None): group.add_argument("--level", required=True, type=int, help="The difficulty level. Must be between 1 and 300 (included).") + group.add_argument("--force-entity-numbering", required=True, action="store_true", + help="This will set `--entity-numbering` to be True which is required for this challenge.") + return parser @@ -52,7 +49,7 @@ def make(settings: Mapping[str, Any], options: Optional[GameOptions] = None) -> :py:class:`textworld.GameOptions ` for the list of available options). - .. warning:: This challenge enforces `options.grammar.allowed_variables_numbering` to be `True`. + .. warning:: This challenge requires `options.grammar.allowed_variables_numbering` to be `True`. Returns: Generated game. @@ -71,11 +68,9 @@ def make(settings: Mapping[str, Any], options: Optional[GameOptions] = None) -> """ options = options or GameOptions() - # Load knowledge base specific to this challenge. - options.kb = KnowledgeBase.load(KB_PATH) - # Needed for games with a lot of rooms. - options.grammar.allowed_variables_numbering = True + options.grammar.allowed_variables_numbering = settings["force_entity_numbering"] + assert options.grammar.allowed_variables_numbering level = settings["level"] if level < 1 or level > 300: @@ -172,7 +167,7 @@ def make_game(mode: str, options: GameOptions) -> textworld.Game: # Generate the quest thats by collecting the coin. quest = Quest(win_events=[ - Event(conditions={M.new_fact("in", coin, M.inventory)}) + EventCondition(conditions={M.new_fact("in", coin, M.inventory)}) ]) M.quests = [quest] diff --git a/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl b/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl index 450e47bd..6783223b 100644 --- a/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl +++ b/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl @@ -4,10 +4,6 @@ type P { look :: at(P, r) -> at(P, r); # Nothing changes. } - reverse_rules { - look :: look; - } - inform7 { commands { look :: "look" :: "looking"; diff --git a/textworld/challenges/tw_simple/textworld_data/logic/key.twl b/textworld/challenges/tw_simple/textworld_data/logic/key.twl index ff6d0499..c7da05e5 100644 --- a/textworld/challenges/tw_simple/textworld_data/logic/key.twl +++ b/textworld/challenges/tw_simple/textworld_data/logic/key.twl @@ -1,25 +1,25 @@ -# key -type k : o { - predicates { - match(k, c); - match(k, d); - } +# # key +# type k : o { +# predicates { +# match(k, c); +# match(k, d); +# } - constraints { - k1 :: match(k, c) & match(k', c) -> fail(); - k2 :: match(k, c) & match(k, c') -> fail(); - k3 :: match(k, d) & match(k', d) -> fail(); - k4 :: match(k, d) & match(k, d') -> fail(); - } +# constraints { +# k1 :: match(k, c) & match(k', c) -> fail(); +# k2 :: match(k, c) & match(k, c') -> fail(); +# k3 :: match(k, d) & match(k', d) -> fail(); +# k4 :: match(k, d) & match(k, d') -> fail(); +# } - inform7 { - type { - kind :: "key"; - } +# inform7 { +# type { +# kind :: "key"; +# } - predicates { - match(k, c) :: "The matching key of the {c} is the {k}"; - match(k, d) :: "The matching key of the {d} is the {k}"; - } - } -} +# predicates { +# match(k, c) :: "The matching key of the {c} is the {k}"; +# match(k, d) :: "The matching key of the {d} is the {k}"; +# } +# } +# } diff --git a/textworld/challenges/tw_simple/textworld_data/logic/room.twl b/textworld/challenges/tw_simple/textworld_data/logic/room.twl index 58715688..62bde7f0 100644 --- a/textworld/challenges/tw_simple/textworld_data/logic/room.twl +++ b/textworld/challenges/tw_simple/textworld_data/logic/room.twl @@ -7,16 +7,15 @@ type r { north_of(r, r); west_of(r, r); - north_of/d(r, d, r); - west_of/d(r, d, r); - free(r, r); south_of(r, r') = north_of(r', r); east_of(r, r') = west_of(r', r); - south_of/d(r, d, r') = north_of/d(r', d, r); - east_of/d(r, d, r') = west_of/d(r', d, r); + # north_of/d(r, d, r); + # west_of/d(r, d, r); + # south_of/d(r, d, r') = north_of/d(r', d, r); + # east_of/d(r, d, r') = west_of/d(r', d, r); } rules { @@ -66,10 +65,10 @@ type r { east_of(r, r') :: "The {r} is mapped east of {r'}"; west_of(r, r') :: "The {r} is mapped west of {r'}"; - north_of/d(r, d, r') :: "South of {r} and north of {r'} is a door called {d}"; - south_of/d(r, d, r') :: "North of {r} and south of {r'} is a door called {d}"; - east_of/d(r, d, r') :: "West of {r} and east of {r'} is a door called {d}"; - west_of/d(r, d, r') :: "East of {r} and west of {r'} is a door called {d}"; + # north_of/d(r, d, r') :: "South of {r} and north of {r'} is a door called {d}"; + # south_of/d(r, d, r') :: "North of {r} and south of {r'} is a door called {d}"; + # east_of/d(r, d, r') :: "West of {r} and east of {r'} is a door called {d}"; + # west_of/d(r, d, r') :: "East of {r} and west of {r'} is a door called {d}"; } commands { diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 53f3998d..0945c194 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -70,9 +70,72 @@ def _get_name_mapping(action): return commands -class Event: +class PropositionControl: """ - Event happening in TextWorld. + Controlling the proposition's appearance within the game. + + When a proposition is activated in the state set, it may be important to track this event. This basically is + determined in the quest design directly or indirectly. This class manages the creation of the event propositions, + Add or Remove the event proposition from the state set, etc. + + Attributes: + + """ + + def __init__(self, props: Iterable[Proposition], verbs: dict): + + self.propositions = props + self.verbs = verbs + self.traceable_propositions, self.addon = self.set_events() + + def set_events(self): + variables = sorted(set([v for c in self.propositions for v in c.arguments])) + event = Proposition("event", arguments=variables) + + if self.verbs: + state_event = [Proposition(name=self.verbs[prop.definition].replace(' ', '_') + '__' + prop.definition, + arguments=prop.arguments, definition=prop.definition, + verb=self.verbs[prop.definition], activate=0) + for prop in self.propositions if prop.definition in self.verbs.keys()] + for p in state_event: + p.activate = 0 + else: + state_event = [] + + return state_event, event + + @classmethod + def add_propositions(cls, props: Iterable[Proposition]) -> Iterable[Proposition]: + for prop in props: + if not prop.name.startswith("is__") and (prop.verb == "has been"): + prop.activate = 1 + + return props + + @classmethod + def set_activated(cls, prop: Proposition): + if not prop.activate: + prop.activate = 1 + + @classmethod + def remove(cls, prop: Proposition, state: State): + if not prop.name.startswith('was__'): + return + + if prop.activate and (prop in state.facts): + if Proposition(prop.definition, prop.arguments) not in state.facts: + state.remove_fact(prop) + + def has_traceable(self): + for prop in self.get_facts(): + if not prop.name.startswith('is__'): + return True + return False + + +class EventCondition: + """ + EventCondition happening in TextWorld. An event gets triggered when its set of conditions become all satisfied. @@ -85,7 +148,8 @@ class Event: def __init__(self, actions: Iterable[Action] = (), conditions: Iterable[Proposition] = (), - commands: Iterable[str] = ()) -> None: + commands: Iterable[str] = (), + output_verb_tense: dict = ()) -> None: """ Args: actions: The actions to be performed to trigger this event. @@ -97,7 +161,8 @@ def __init__(self, actions: Iterable[Action] = (), """ self.actions = actions self.commands = commands - self.condition = self.set_conditions(conditions) + self.verb_tense = output_verb_tense + self.condition, self.traceable = self.set_conditions(conditions) @property def actions(self) -> Iterable[Action]: @@ -139,96 +204,96 @@ def set_conditions(self, conditions: Iterable[Proposition]) -> Action: # last action in the quest. conditions = self.actions[-1].postconditions - variables = sorted(set([v for c in conditions for v in c.arguments])) - event = Proposition("event", arguments=variables) - self.condition = Action("trigger", preconditions=conditions, - postconditions=list(conditions) + [event]) - return self.condition + event = PropositionControl(conditions, self.verb_tense) + traceable = event.traceable_propositions + condition = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event.addon]) + + # The corresponding traceable(s) should be active in state set to be considered for the event. + if condition.has_traceable(): + condition.activate_traceable() + + return condition, traceable def __hash__(self) -> int: - return hash((self.actions, self.commands, self.condition)) + return hash((self.actions, self.commands, self.condition, self.traceable, + self.verb_tense)) def __eq__(self, other: Any) -> bool: - return (isinstance(other, Event) - and self.actions == other.actions - and self.commands == other.commands - and self.condition == other.condition) + return (isinstance(other, EventCondition) and + self.actions == other.actions and + self.commands == other.commands and + self.condition == other.condition and + self.verb_tense == other.verb_tense and + self.traceable == other.traceable) @classmethod - def deserialize(cls, data: Mapping) -> "Event": - """ Creates an `Event` from serialized data. + def deserialize(cls, data: Mapping) -> "EventCondition": + """ Creates an `EventCondition` from serialized data. Args: - data: Serialized data with the needed information to build a - `Event` object. + data: Serialized data with the needed information to build a `EventCondotion` object. """ actions = [Action.deserialize(d) for d in data["actions"]] condition = Action.deserialize(data["condition"]) - event = cls(actions, condition.preconditions, data["commands"]) + event = cls(actions, condition.preconditions, data["commands"], data["output_verb_tense"]) return event def serialize(self) -> Mapping: """ Serialize this event. Results: - `Event`'s data serialized to be JSON compatible. + `EventCondition`'s data serialized to be JSON compatible. """ data = {} - # data["class"] = "Event" data["commands"] = self.commands data["actions"] = [action.serialize() for action in self.actions] data["condition"] = self.condition.serialize() + data["output_verb_tense"] = self.verb_tense return data - def copy(self) -> "Event": + def copy(self) -> "EventCondition": """ Copy this event. """ return self.deserialize(self.serialize()) class EventAction: - def __init__(self, action: Iterable[Action] = (), precond_verb_tense: dict = (), postcond_verb_tense: dict = ()) -> None: - self.verb_tense_precond = precond_verb_tense - self.verb_tense_postcond = postcond_verb_tense - self.event, self.actions = self.set_actions(action) - - def set_parameters(self, output: dict, acts: Iterable[Proposition], verbs: Iterable[dict]): - def tense(val): - if val == 1: - return 'will' - elif val == 0: - return 'is' - elif val == -1: - return 'was' - elif val == -2: - return 'has been' - elif val == -3: - return 'had been' - - if not verbs: - return output - - for prop in acts: - if prop.name in verbs.keys(): - output['name'].append(prop.name) - [output['argument'].append(v) for v in prop.arguments] - output['verb_val'].append(verbs[prop.name]) - output['verb_def'].append(tense(verbs[prop.name])) - - return output - - def set_actions(self, action: Iterable[Action]): - # tp_action = action - tp_action = [a for a in action] - params = {'name': [], 'argument': [], 'verb_val': [], 'verb_def': []} - params = self.set_parameters(params, tp_action[0].removed, self.verb_tense_precond) - params = self.set_parameters(params, tp_action[0].added, self.verb_tense_postcond) - event = Proposition("event", arguments=params['argument'], definition=params['name'], - verb_var=params['verb_val'], verb_def=params['verb_def']) - return event, action + def __init__(self, action: Iterable[Action] = (), + output_verb_tense_precond: dict = (), + output_verb_tense_postcond: dict = ()) -> None: + self.verb_tense_precond = output_verb_tense_precond + self.verb_tense_postcond = output_verb_tense_postcond + self.verb_tense = self.set_verbs() + self.actions = list(action) + self.traceable = self.set_actions() + + def set_verbs(self): + def mergeDict(dict1, dict2): + """ + Merge dictionaries and keep values of common keys in list + """ + + dict3 = {**dict1, **dict2} + for key, value in dict3.items(): + if key in dict1 and key in dict2: + dict3[key] = [value, dict1[key]] + + return dict3 + + return mergeDict(dict(self.verb_tense_precond), dict(self.verb_tense_postcond)) + + def set_actions(self): + props = [] + for p in self.actions[0].all_propositions: + if p not in props: + props.append(p) + + event = PropositionControl(props, self.verb_tense) + traceable = event.traceable_propositions + return traceable def is_triggering(self, action: Action) -> bool: """ Check if this event would be triggered for a given action. """ - return action == [a for a in self.actions][0] + return action == self.actions[0] @classmethod def deserialize(cls, data: Mapping) -> "EventAction": @@ -238,8 +303,8 @@ def deserialize(cls, data: Mapping) -> "EventAction": data: Serialized data with the needed information to build a `EventAction` object. """ - actions = [Action.deserialize(d) for d in data["actions"]] - event = cls(actions, data["precond_verb_tense"], data["postcond_verb_tense"]) + action = [Action.deserialize(d) for d in data["action"]] + event = cls(action, data["output_verb_tense_precond"], data["output_verb_tense_postcond"]) return event def serialize(self) -> Mapping: @@ -248,20 +313,25 @@ def serialize(self) -> Mapping: Results: `EventAction`'s data serialized to be JSON compatible. """ - return {"actions": [action.serialize() for action in self.actions], - "precond_verb_tense": self.verb_tense_precond, - "postcond_verb_tense": self.verb_tense_postcond, + return {"action": [action.serialize() for action in self.actions], + "output_verb_tense_precond": self.verb_tense_precond, + "output_verb_tense_postcond": self.verb_tense_postcond, } def __hash__(self) -> int: - return hash((self.actions, self.event, self.verb_tense_precond, self.verb_tense_postcond)) + return hash((self.actions, self.verb_tense, self.verb_tense_precond, self.verb_tense_postcond, self.traceable)) def __eq__(self, other: Any) -> bool: return (isinstance(other, EventAction) and self.actions == other.actions and - self.event == other.event and + self.verb_tense == other.verb_tense and self.verb_tense_precond == other.verb_tense_precond and - self.verb_tense_postcond == other.verb_tense_postcond) + self.verb_tense_postcond == other.verb_tense_postcond and + self.traceable == other.traceable) + + def copy(self) -> "EventAction": + """ Copy this event. """ + return self.deserialize(self.serialize()) class Quest: @@ -283,10 +353,8 @@ class Quest: """ def __init__(self, - # win_events: Iterable[Event] = (), - # fail_events: Iterable[Event] = (), - win_events: Iterable = (), - fail_events: Iterable = (), + win_events: Iterable[Union[EventCondition, EventAction]] = (), + fail_events: Iterable[Union[EventCondition, EventAction]] = (), reward: Optional[int] = None, desc: Optional[str] = None, commands: Iterable[str] = ()) -> None: @@ -317,19 +385,19 @@ def __init__(self, raise UnderspecifiedQuestError() @property - def win_events(self) -> Iterable[Event]: + def win_events(self) -> Iterable[EventCondition]: return self._win_events @win_events.setter - def win_events(self, events: Iterable[Event]) -> None: + def win_events(self, events: Iterable[EventCondition]) -> None: self._win_events = tuple(events) @property - def fail_events(self) -> Iterable[Event]: + def fail_events(self) -> Iterable[EventCondition]: return self._fail_events @fail_events.setter - def fail_events(self, events: Iterable[Event]) -> None: + def fail_events(self, events: Iterable[EventCondition]) -> None: self._fail_events = tuple(events) @property @@ -370,19 +438,19 @@ def deserialize(cls, data: Mapping) -> "Quest": """ win_events = [] for d in data["win_events"]: - if "precond_verb_tense" in d.keys(): + if "output_verb_tense_precond" in d.keys(): win_events.append(EventAction.deserialize(d)) if "condition" in d.keys(): - win_events.append(Event.deserialize(d)) + win_events.append(EventCondition.deserialize(d)) fail_events = [] for d in data["fail_events"]: - if "precond_verb_tense" in d.keys(): + if "output_verb_tense_precond" in d.keys(): fail_events.append(EventAction.deserialize(d)) if "condition" in d.keys(): - fail_events.append(Event.deserialize(d)) + fail_events.append(EventCondition.deserialize(d)) commands = data.get("commands", []) reward = data["reward"] @@ -545,7 +613,7 @@ def change_grammar(self, grammar: Grammar) -> None: mapping = {k: info.name for k, info in self._infos.items()} commands = [a.format_command(mapping) for a in policy] self.metadata["walkthrough"] = commands - self.objective = describe_event(Event(policy), self, self.grammar) + self.objective = describe_event(EventCondition(policy), self, self.grammar) def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ @@ -816,7 +884,7 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event, kb: KnowledgeBase) -> None: + def __init__(self, event: Union[EventCondition, EventAction], kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. @@ -831,7 +899,7 @@ def __init__(self, event, kb: KnowledgeBase) -> None: self._tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) - if isinstance(event, Event): + if isinstance(event, EventCondition): if len(event.actions) > 0: self._tree.push(event.condition) @@ -840,10 +908,6 @@ def __init__(self, event, kb: KnowledgeBase) -> None: self._policy = event.actions + (event.condition,) - # if isinstance(event, EventAction): - # self._tree.push([a for a in event.actions][0]) - # self._policy = event.actions - def copy(self) -> "EventProgression": """ Return a soft copy. """ ep = EventProgression(self.event, self._kb) @@ -890,7 +954,7 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if state is not None: # Check if event is triggered. - if isinstance(self.event, Event): + if isinstance(self.event, EventCondition): self._triggered = self.event.is_triggering(state) if isinstance(self.event, EventAction): @@ -943,6 +1007,15 @@ def _find_shorter_policy(policy): return compressed + def will_trigger(self, state: State, action: Action): + if isinstance(self.event, EventCondition): + triggered = self.event.is_triggering(state) + + if isinstance(self.event, EventAction): + triggered = self.event.is_triggering(action) + + return triggered + class QuestProgression: """ QuestProgression keeps track of the completion of a quest. @@ -1043,8 +1116,7 @@ def __init__(self, game: Game, track_quests: bool = True) -> None: """ self.game = game self.state = game.world.state.copy() - self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) + self._valid_actions = self.valid_actions_gen() self.quest_progressions = [] if track_quests: self.quest_progressions = [QuestProgression(quest, game.kb) for quest in game.quests] @@ -1061,6 +1133,11 @@ def copy(self) -> "GameProgression": return gp + def valid_actions_gen(self): + potential_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) + return [act for act in potential_actions if act.is_valid()] + @property def done(self) -> bool: """ Whether all quests are completed or at least one has failed or is unfinishable. """ @@ -1124,34 +1201,43 @@ def winning_policy(self) -> Optional[List[Action]]: # Discard all "trigger" actions. return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") + def add_traceables(self, action): + s = self.state.facts + for quest_progression in self.quest_progressions: + if not quest_progression.completed and (quest_progression.quest.reward >= 0): + for win_event in quest_progression.win_events: + if win_event.event.traceable and not (win_event.event.traceable in s): + if win_event.will_trigger(self.state, action): + self.state.add_facts(PropositionControl.add_propositions(win_event.event.traceable)) + + def traceable_manager(self): + if not self.state.has_traceable(): + return + + for prop in self.state.get_facts(): + if not prop.name.startswith('is__'): + PropositionControl.set_activated(prop) + PropositionControl.remove(prop, self.state) + def update(self, action: Action) -> None: """ Update the state of the game given the provided action. Args: action: Action affecting the state of the game. """ - # Update world facts. + # Update world facts self.state.apply(action) + self.add_traceables(action) - # Get valid actions. - self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) # Update all quest progressions given the last action and new state. for quest_progression in self.quest_progressions: quest_progression.update(action, self.state) - for quest_progression in self.quest_progressions: - for win_event in quest_progression.win_events: - if quest_progression.quest.reward >= 0: - if isinstance(win_event.event, Event): - self.state.apply(win_event.event.condition) - if isinstance(win_event.event, EventAction): - propos = [prop for prop in win_event.event.actions[0].added] - self.state.apply(Action("trigger", preconditions=propos, - postconditions=propos + [win_event.event.event])) - - self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) + # Update world facts. + self.traceable_manager() + + # Get valid actions. + self._valid_actions = self.valid_actions_gen() class GameOptions: diff --git a/textworld/logic/__init__.py b/textworld/logic/__init__.py index c5bf4966..6cf7b0a5 100644 --- a/textworld/logic/__init__.py +++ b/textworld/logic/__init__.py @@ -63,6 +63,24 @@ def _check_type_conflict(name, old_type, new_type): raise ValueError("Conflicting types for `{}`: have `{}` and `{}`.".format(name, old_type, new_type)) +class UnderspecifiedSignatureError(NameError): + def __init__(self): + msg = "The verb and definition of the signature either should both be None or both take values." + super().__init__(msg) + + +class UnderspecifiedPredicateError(NameError): + def __init__(self): + msg = "The verb and definition of the predicate either should both be None or both take values." + super().__init__(msg) + + +class UnderspecifiedPropositionError(NameError): + def __init__(self): + msg = "The verb and definition of the proposition either should both be None or both take values." + super().__init__(msg) + + class _ModelConverter(NodeWalker): """ Converts TatSu model objects to our types. @@ -118,10 +136,10 @@ def walk_VariableNode(self, node): return self._walk_variable_ish(node, Variable) def walk_SignatureNode(self, node): - return Signature(node.name, node.types) + return Signature(node.name, node.types, node.verb, node.definition) def walk_PropositionNode(self, node): - return Proposition(node.name, self.walk(node.arguments)) + return Proposition(node.name, self.walk(node.arguments), node.verb, node.definition) def walk_ActionNode(self, node): return self._walk_action_ish(node, Action) @@ -130,7 +148,7 @@ def walk_PlaceholderNode(self, node): return self._walk_variable_ish(node, Placeholder) def walk_PredicateNode(self, node): - return Predicate(node.name, self.walk(node.parameters)) + return Predicate(node.name, self.walk(node.parameters), node.verb, node.definition) def walk_RuleNode(self, node): return self._walk_action_ish(node, Rule) @@ -536,7 +554,9 @@ def deserialize(cls, data: Mapping) -> "Variable": lambda cls, args, kwargs: ( cls, kwargs.get("name", args[0] if len(args) >= 1 else None), - tuple(kwargs.get("types", args[1] if len(args) == 2 else [])) + tuple(kwargs.get("types", args[1] if len(args) >= 2 else [])), + kwargs.get("verb", args[2] if len(args) >= 3 else None), + kwargs.get("definition", args[3] if len(args) == 4 else None), ) ) @@ -547,9 +567,9 @@ class Signature(with_metaclass(SignatureTracker, object)): The type signature of a Predicate or Proposition. """ - __slots__ = ("name", "types", "_hash") + __slots__ = ("name", "types", "_hash", "verb", "definition") - def __init__(self, name: str, types: Iterable[str]): + def __init__(self, name: str, types: Iterable[str], verb=None, definition=None): """ Create a Signature. @@ -560,10 +580,22 @@ def __init__(self, name: str, types: Iterable[str]): types : The types of the parameters to the proposition/predicate. """ + if (not verb and definition) or (verb and not definition): + raise UnderspecifiedSignatureError + + if name.count('__') == 0: + verb = "is" + definition = name + name = "is__"+name + else: + verb = name[:name.find('__')] + definition = name[name.find('__') + 2:] self.name = name self.types = tuple(types) - self._hash = hash((self.name, self.types)) + self.verb = verb + self.definition = definition + self._hash = hash((self.name, self.types, self.verb, self.definition)) def __str__(self): return "{}({})".format(self.name, ", ".join(map(str, self.types))) @@ -573,7 +605,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Signature): - return self.name == other.name and self.types == other.types + return self.name == other.name and self.types == other.types and self.verb == other.verb and self.definition == other.definition else: return NotImplemented @@ -604,7 +636,11 @@ def parse(cls, expr: str) -> "Signature": lambda cls, args, kwargs: ( cls, kwargs.get("name", args[0] if len(args) >= 1 else None), - tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) == 2 else [])) + tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) >= 2 else [])), + kwargs.get("verb", args[2] if len(args) >= 3 else None), + kwargs.get("definition", args[3] if len(args) >= 4 else None), + # kwargs.get("activate", 0) + kwargs.get("activate", args[4] if len(args) == 5 else 0) ) ) @@ -615,10 +651,10 @@ class Proposition(with_metaclass(PropositionTracker, object)): An instantiated Predicate, with concrete variables for each placeholder. """ - __slots__ = ("name", "arguments", "signature", "_hash", "definition", "verb_var", "verb_def") + __slots__ = ("name", "arguments", "signature", "_hash", "verb", "definition", "activate") - def __init__(self, name: str, arguments: Iterable[Variable] = [], definition: str = None, - verb_var: int = None, verb_def: str = None): + def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = None, definition: str = None, + activate: int = 0): """ Create a Proposition. @@ -630,13 +666,28 @@ def __init__(self, name: str, arguments: Iterable[Variable] = [], definition: st The variables this proposition is applied to. """ + if (not verb and definition) or (verb and not definition): + raise UnderspecifiedPropositionError + + if name.count('__') == 0: + verb = "is" + definition = name + name = "is__"+name + else: + verb = name[:name.find('__')].replace('_', ' ') + definition = name[name.find('__') + 2:] + self.name = name self.arguments = tuple(arguments) - self.signature = Signature(name, [var.type for var in self.arguments]) - self._hash = hash((self.name, self.arguments)) + self.verb = verb self.definition = definition - self.verb_var = verb_var - self.verb_def = verb_def + self.signature = Signature(name, [var.type for var in self.arguments], self.verb, self.definition) + self._hash = hash((self.name, self.arguments, self.verb, self.definition)) + + if self.verb == 'is': + activate = 1 + + self.activate = activate @property def names(self) -> Collection[str]: @@ -652,51 +703,16 @@ def types(self) -> Collection[str]: """ return self.signature.types - def make_str(self, max_arg=False): - args = [v for v in self.arguments] - txt = [] - for i in range(len(args)): - if max_arg: - txt.append("({})".format(", ".join(map(str, [args[i], self.verb_def[i], self.definition[i]])))) - else: - txt.append("({})".format(", ".join(map(str, [args[i], self.definition[i]])))) - - return "{}".format(", ".join(txt)) - def __str__(self): - def make_str(max_arg=False): - args = [v for v in self.arguments] - txt = [] - for i in range(len(args)): - if max_arg: - txt.append("({})".format(", ".join(map(str, [args[i], self.verb_def[i], self.definition[i]])))) - else: - txt.append("({})".format(", ".join(map(str, [args[i], self.definition[i]])))) - - return "{}".format(", ".join(txt)) - - if self.definition and self.verb_def: - return "{}({txt})".format(self.name, txt=make_str(max_arg=True)) - - elif self.definition: - return "{}({txt})".format(self.name, txt=make_str()) - - else: - return "{}({})".format(self.name, ", ".join(map(str, self.arguments))) + return "{}({})".format(self.name, ", ".join(map(str, self.arguments))) def __repr__(self): - if self.definition and self.verb: - return "Proposition({!r}, {!r}, {!r}, {!r})".format(self.name, self.arguments, self.definition, - self.verb_def) - elif self.definition: - return "Proposition({!r}, {!r}, {!r})".format(self.name, self.arguments, self.definition) - - else: - return "Proposition({!r}, {!r})".format(self.name, self.arguments) + return "Proposition({!r}, {!r})".format(self.name, self.arguments) def __eq__(self, other): if isinstance(other, Proposition): - return self.name == other.name and self.arguments == other.arguments + return (self.name, self.arguments, self.verb, self.definition, self.activate) == \ + (other.name, other.arguments, other.verb, other.definition, other.activate) else: return NotImplemented @@ -725,13 +741,19 @@ def serialize(self) -> Mapping: return { "name": self.name, "arguments": [var.serialize() for var in self.arguments], + "verb": self.verb, + "definition": self.definition, + "activate": self.activate } @classmethod def deserialize(cls, data: Mapping) -> "Proposition": name = data["name"] args = [Variable.deserialize(arg) for arg in data["arguments"]] - return cls(name, args) + verb = data["verb"] + definition = data["definition"] + activate = data["activate"] + return cls(name, args, verb, definition, activate) @total_ordering @@ -815,7 +837,7 @@ class Predicate: A boolean-valued function over variables. """ - def __init__(self, name: str, parameters: Iterable[Placeholder]): + def __init__(self, name: str, parameters: Iterable[Placeholder], verb=None, definition=None): """ Create a Predicate. @@ -826,10 +848,22 @@ def __init__(self, name: str, parameters: Iterable[Placeholder]): parameters : The symbolic arguments to this predicate. """ + if (not verb and definition) or (verb and not definition): + raise UnderspecifiedPredicateError + + if name.count('__') == 0: + verb = "is" + definition = name + name = "is__" + name + else: + verb = name[:name.find('__')] + definition = name[name.find('__') + 2:] self.name = name self.parameters = tuple(parameters) - self.signature = Signature(name, [ph.type for ph in self.parameters]) + self.verb = verb + self.definition = definition + self.signature = Signature(name, [ph.type for ph in self.parameters], self.verb, self.definition) @property def names(self) -> Collection[str]: @@ -853,12 +887,12 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Predicate): - return (self.name, self.parameters) == (other.name, other.parameters) + return (self.name, self.parameters, self.verb, self.definition) == (other.name, other.parameters, other.verb, other.definition) else: return NotImplemented def __hash__(self): - return hash((self.name, self.parameters)) + return hash((self.name, self.types, self.verb, self.definition)) def __lt__(self, other): if isinstance(other, Predicate): @@ -882,13 +916,17 @@ def serialize(self) -> Mapping: return { "name": self.name, "parameters": [ph.serialize() for ph in self.parameters], + "verb": self.verb, + "definition": self.definition } @classmethod def deserialize(cls, data: Mapping) -> "Predicate": name = data["name"] params = [Placeholder.deserialize(ph) for ph in data["parameters"]] - return cls(name, params) + verb = data["verb"] + definition = data["definition"] + return cls(name, params, verb, definition) def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": """ @@ -901,7 +939,7 @@ def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": """ params = [mapping.get(param, param) for param in self.parameters] - return Predicate(self.name, params) + return Predicate(self.name, params, self.verb, self.definition) def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: """ @@ -918,7 +956,13 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: """ args = [mapping[param] for param in self.parameters] - return Proposition(self.name, args) + return Proposition(self.name, arguments=args, verb=self.verb, definition=self.definition) + + # args = [mapping[param] for param in self.parameters] + # if Proposition.name == 'event': + # return Proposition(self.name, args, verb=special.verb, definition=special.definition) + # else: + # return Proposition(self.name, args) def match(self, proposition: Proposition) -> Optional[Mapping[Placeholder, Variable]]: """ @@ -1105,6 +1149,20 @@ def format_command(self, mapping: Dict[str, str] = {}): mapping = mapping or {v.name: v.name for v in self.variables} return self.command_template.format(**mapping) + def has_traceable(self): + for prop in self.all_propositions: + if not prop.name.startswith('is__'): + return True + return False + + def activate_traceable(self): + for prop in self.all_propositions: + if not prop.name.startswith('is__'): + prop.activate = 1 + + def is_valid(self): + return all([prop.activate == 1 for prop in self.all_propositions]) + class Rule: """ @@ -1231,7 +1289,6 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Action: ------- The instantiated Action with each Placeholder mapped to the corresponding Variable. """ - key = tuple(mapping[ph] for ph in self.placeholders) if key in self._cache: return self._cache[key] @@ -1239,7 +1296,8 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Action: pre_inst = [pred.instantiate(mapping) for pred in self.preconditions] post_inst = [pred.instantiate(mapping) for pred in self.postconditions] action = Action(self.name, pre_inst, post_inst) - + if action.has_traceable(): + action.activate_traceable() action.command_template = self._make_command_template(mapping) if self.reverse_rule: action.reverse_name = self.reverse_rule.name @@ -1505,6 +1563,26 @@ def _normalize_predicates(self, predicates): result.append(pred) return result + def _predicate_diversity(self): + new_preds = [] + for pred in self.predicates: + for v in ['was', 'has been', 'had been']: + new_preds.append(Signature(name=v.replace(' ', '_') + pred.name[pred.name.find('__'):], types=pred.types, + verb=v, definition=pred.definition)) + self.predicates.update(set(new_preds)) + + def _inform7_predicates_diversity(self): + new_preds = {} + for k, v in self.inform7.predicates.items(): + for vt in ['was', 'has been', 'had been']: + new_preds[Signature(name=vt.replace(' ', '_') + k.name[k.name.find('__'):], types=k.types, + verb=vt, definition=k.definition)] = \ + Inform7Predicate(predicate=Predicate(name=vt.replace(' ', '_') + v.predicate.name[v.predicate.name.find('__'):], + parameters=v.predicate.parameters, verb=vt, + definition=v.predicate.definition), + source=v.source.replace('is', vt)) + self.inform7.predicates.update(new_preds) + @classmethod @lru_cache(maxsize=128, typed=False) def parse(cls, document: str) -> "GameLogic": @@ -1519,6 +1597,8 @@ def load(cls, paths: Iterable[str]): for path in paths: with open(path, "r") as f: result._parse(f.read(), path=path) + result._predicate_diversity() + result._inform7_predicates_diversity() result._initialize() return result @@ -1624,6 +1704,9 @@ def are_facts(self, props: Iterable[Proposition]) -> bool: if not self.is_fact(prop): return False + if not prop.activate: + return False + return True @property @@ -1825,7 +1908,6 @@ def all_assignments(self, seen_phs.add(ph) new_phs_by_depth.append(new_phs) - # Placeholders uniquely found in postcondition are considered as free variables. free_vars = [ph for ph in rule.placeholders if ph not in seen_phs] new_phs_by_depth.append(free_vars) @@ -1969,3 +2051,20 @@ def __str__(self): lines.append("})") return "\n".join(lines) + + def get_facts(self): + all_facts = [] + for sig in sorted(self._facts.keys()): + facts = self._facts[sig] + if len(facts) == 0: + continue + for fact in sorted(facts): + all_facts.append(fact) + return all_facts + + def has_traceable(self): + for prop in self.get_facts(): + if not prop.name.startswith('is__'): + return True + return False + From c2940162dd7027e3b67ebffb8efee9e8bdf7cb1b Mon Sep 17 00:00:00 2001 From: HakiRose Date: Sat, 18 Apr 2020 01:15:33 -0400 Subject: [PATCH 3/5] New TextWorld framework structure includes: new quest description, new quest/Event design, new world2inform7 structure, and more. --- textworld/generator/game.py | 796 +++++++++++++++---- textworld/generator/inform7/world2inform7.py | 266 ++++++- textworld/generator/maker.py | 95 ++- textworld/generator/text_generation.py | 316 +++++++- textworld/generator/world.py | 4 +- textworld/logic/__init__.py | 167 ++-- textworld/logic/parser.py | 31 +- 7 files changed, 1342 insertions(+), 333 deletions(-) diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 0945c194..8a5c50dd 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -5,6 +5,7 @@ import copy import json import textwrap +import re from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple from collections import OrderedDict @@ -12,6 +13,7 @@ from numpy.random import RandomState +import textworld from textworld import g_rng from textworld.utils import encode_seeds from textworld.generator.data import KnowledgeBase @@ -94,35 +96,19 @@ def set_events(self): if self.verbs: state_event = [Proposition(name=self.verbs[prop.definition].replace(' ', '_') + '__' + prop.definition, - arguments=prop.arguments, definition=prop.definition, - verb=self.verbs[prop.definition], activate=0) + arguments=prop.arguments) for prop in self.propositions if prop.definition in self.verbs.keys()] - for p in state_event: - p.activate = 0 else: state_event = [] return state_event, event - @classmethod - def add_propositions(cls, props: Iterable[Proposition]) -> Iterable[Proposition]: - for prop in props: - if not prop.name.startswith("is__") and (prop.verb == "has been"): - prop.activate = 1 - - return props - - @classmethod - def set_activated(cls, prop: Proposition): - if not prop.activate: - prop.activate = 1 - @classmethod def remove(cls, prop: Proposition, state: State): if not prop.name.startswith('was__'): return - if prop.activate and (prop in state.facts): + if prop in state.facts: if Proposition(prop.definition, prop.arguments) not in state.facts: state.remove_fact(prop) @@ -133,39 +119,29 @@ def has_traceable(self): return False -class EventCondition: - """ - EventCondition happening in TextWorld. - - An event gets triggered when its set of conditions become all satisfied. - - Attributes: - actions: Actions to be performed to trigger this event - commands: Human readable version of the actions. - condition: :py:class:`textworld.logic.Action` that can only be applied - when all conditions are satisfied. - """ +class Event: - def __init__(self, actions: Iterable[Action] = (), - conditions: Iterable[Proposition] = (), - commands: Iterable[str] = (), - output_verb_tense: dict = ()) -> None: + def __init__(self, actions: Iterable[Action] = (), commands: Iterable[str] = ()) -> None: """ Args: actions: The actions to be performed to trigger this event. - If an empty list, then `conditions` must be provided. - conditions: Set of propositions which need to - be all true in order for this event - to get triggered. commands: Human readable version of the actions. """ - self.actions = actions + + self.actions = list(actions) + self.commands = commands - self.verb_tense = output_verb_tense - self.condition, self.traceable = self.set_conditions(conditions) @property - def actions(self) -> Iterable[Action]: + def verb_tense(self) -> dict: + return self._verb_tense + + @verb_tense.setter + def verb_tense(self, verb: dict) -> None: + self._verb_tense = verb + + @property + def actions(self) -> Tuple[Action]: return self._actions @actions.setter @@ -180,10 +156,62 @@ def commands(self) -> Iterable[str]: def commands(self, commands: Iterable[str]) -> None: self._commands = tuple(commands) - def is_triggering(self, state: State) -> bool: - """ Check if this event would be triggered in a given state. """ + def __hash__(self) -> int: + return hash((self.actions, self.commands)) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, Event) and + self.actions == other.actions and + self.commands == other.commands) + + @classmethod + def deserialize(cls, data: Mapping) -> "Event": + """ Creates an `Event` from serialized data. + + Args: + data: Serialized data with the needed information to build a `Event` object. + """ + actions = [Action.deserialize(d) for d in data["actions_Event"]] + return cls(actions, data["commands_Event"]) + + def serialize(self) -> Mapping: + """ Serialize this event. + + Results: + `Event`'s data serialized to be JSON compatible. + """ + return {"commands_Event": self.commands, + "actions_Event": [action.serialize() for action in self.actions]} + # data = {} + # data["commands"] = self.commands + # data["actions"] = [action.serialize() for action in self.actions] + # return data + + def copy(self) -> "Event": + """ Copy this event. """ + return self.deserialize(self.serialize()) - return state.is_applicable(self.condition) + +class EventCondition(Event): + def __init__(self, conditions: Iterable[Proposition] = (), + verb_tense: dict = (), + actions: Iterable[Action] = (), + commands: Iterable[str] = (), + ) -> None: + """ + Args: + actions: The actions to be performed to trigger this event. + If an empty list, then `conditions` must be provided. + conditions: Set of propositions which need to be all true in order for this event + to get triggered. + commands: Human readable version of the actions. + verb_tense: The desired verb tense for any state propositions which are been tracking. + """ + super(EventCondition, self).__init__(actions, commands) + + self.verb_tense = verb_tense + + self.condition = self.set_conditions(conditions) def set_conditions(self, conditions: Iterable[Proposition]) -> Action: """ @@ -205,18 +233,29 @@ def set_conditions(self, conditions: Iterable[Proposition]) -> Action: conditions = self.actions[-1].postconditions event = PropositionControl(conditions, self.verb_tense) - traceable = event.traceable_propositions + self.traceable = event.traceable_propositions condition = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event.addon]) - # The corresponding traceable(s) should be active in state set to be considered for the event. - if condition.has_traceable(): - condition.activate_traceable() + return condition + + def is_valid(self): + return isinstance(self.condition, Action) + + def is_triggering(self, state: State, actions: Iterable[Action] = ()) -> bool: + """ Check if this event would be triggered in a given state. """ + + return state.is_applicable(self.condition) + + @property + def traceable(self) -> Iterable[Proposition]: + return self._traceable - return condition, traceable + @traceable.setter + def traceable(self, traceable: Iterable[Proposition]) -> None: + self._traceable = tuple(traceable) def __hash__(self) -> int: - return hash((self.actions, self.commands, self.condition, self.traceable, - self.verb_tense)) + return hash((self.actions, self.commands, self.condition, self.verb_tense, self.traceable)) def __eq__(self, other: Any) -> bool: return (isinstance(other, EventCondition) and @@ -233,10 +272,9 @@ def deserialize(cls, data: Mapping) -> "EventCondition": Args: data: Serialized data with the needed information to build a `EventCondotion` object. """ - actions = [Action.deserialize(d) for d in data["actions"]] - condition = Action.deserialize(data["condition"]) - event = cls(actions, condition.preconditions, data["commands"], data["output_verb_tense"]) - return event + actions = [Action.deserialize(d) for d in data["actions_EventCondition"]] + condition = Action.deserialize(data["condition_EventCondition"]) + return cls(condition.preconditions, data["verb_tense_EventCondition"], actions, data["commands_EventCondition"]) def serialize(self) -> Mapping: """ Serialize this event. @@ -244,56 +282,79 @@ def serialize(self) -> Mapping: Results: `EventCondition`'s data serialized to be JSON compatible. """ - data = {} - data["commands"] = self.commands - data["actions"] = [action.serialize() for action in self.actions] - data["condition"] = self.condition.serialize() - data["output_verb_tense"] = self.verb_tense - return data + return {"commands_EventCondition": self.commands, + "actions_EventCondition": [action.serialize() for action in self.actions], + "condition_EventCondition": self.condition.serialize(), + "verb_tense_EventCondition": self.verb_tense} + # data = {} + # data["commands"] = self.commands + # data["actions"] = [action.serialize() for action in self.actions] + # data["condition"] = self.condition.serialize() + # data["verb_tense"] = self.verb_tense + # return data def copy(self) -> "EventCondition": """ Copy this event. """ return self.deserialize(self.serialize()) -class EventAction: - def __init__(self, action: Iterable[Action] = (), - output_verb_tense_precond: dict = (), - output_verb_tense_postcond: dict = ()) -> None: - self.verb_tense_precond = output_verb_tense_precond - self.verb_tense_postcond = output_verb_tense_postcond - self.verb_tense = self.set_verbs() - self.actions = list(action) - self.traceable = self.set_actions() - - def set_verbs(self): - def mergeDict(dict1, dict2): - """ - Merge dictionaries and keep values of common keys in list - """ +class EventAction(Event): - dict3 = {**dict1, **dict2} - for key, value in dict3.items(): - if key in dict1 and key in dict2: - dict3[key] = [value, dict1[key]] + def __init__(self, actions: Iterable[Action] = (), + verb_tense: dict = (), + commands: Iterable[str] = ()) -> None: + """ + Args: + actions: The actions to be performed to trigger this event. + commands: Human readable version of the actions. + verb_tense: The desired verb tense for any state propositions which are been tracking. + """ + super(EventAction, self).__init__(actions, commands) - return dict3 + self.verb_tense = verb_tense - return mergeDict(dict(self.verb_tense_precond), dict(self.verb_tense_postcond)) + self.traceable = self.set_actions() def set_actions(self): - props = [] - for p in self.actions[0].all_propositions: - if p not in props: - props.append(p) + traceable = [] + for act in self.actions: + props = [] + for p in act.all_propositions: + if p not in props: + props.append(p) - event = PropositionControl(props, self.verb_tense) - traceable = event.traceable_propositions - return traceable + event = PropositionControl(props, self.verb_tense) + traceable.append(event.traceable_propositions) - def is_triggering(self, action: Action) -> bool: + return [prop for ar in traceable for prop in ar] + + def is_valid(self): + return len(self.actions) != 0 + + def is_triggering(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this event would be triggered for a given action. """ - return action == self.actions[0] + if not actions: + return False + + return all((actions[i] == self.actions[i] for i in range(len(actions)))) + + @property + def traceable(self) -> Iterable[Proposition]: + return self._traceable + + @traceable.setter + def traceable(self, traceable: Iterable[Proposition]) -> None: + self._traceable = tuple(traceable) + + def __hash__(self) -> int: + return hash((self.actions, self.commands, self.verb_tense, self.traceable)) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventAction) and + self.actions == other.actions and + self.commands == other.commands and + self.verb_tense == other.verb_tense and + self.traceable == other.traceable) @classmethod def deserialize(cls, data: Mapping) -> "EventAction": @@ -303,9 +364,8 @@ def deserialize(cls, data: Mapping) -> "EventAction": data: Serialized data with the needed information to build a `EventAction` object. """ - action = [Action.deserialize(d) for d in data["action"]] - event = cls(action, data["output_verb_tense_precond"], data["output_verb_tense_postcond"]) - return event + action = [Action.deserialize(d) for d in data["actions_EventAction"]] + return cls(action, data["verb_tense_EventAction"], data["commands_EventAction"]) def serialize(self) -> Mapping: """ Serialize this event. @@ -313,24 +373,168 @@ def serialize(self) -> Mapping: Results: `EventAction`'s data serialized to be JSON compatible. """ - return {"action": [action.serialize() for action in self.actions], - "output_verb_tense_precond": self.verb_tense_precond, - "output_verb_tense_postcond": self.verb_tense_postcond, + return {"actions_EventAction": [action.serialize() for action in self.actions], + "commands_EventAction": self.commands, + "verb_tense_EventAction": self.verb_tense, } + # return {"actions": [action.serialize() for action in self.actions], + # "commands": self.commands, + # "verb_tense": self.verb_tense, + # } + + def copy(self) -> "EventAction": + """ Copy this event. """ + return self.deserialize(self.serialize()) + + +class EventOr: + def __init__(self, events=()): + self.events = events + self._any_triggered = False + self._any_untriggered = False + + @property + def events(self) -> Tuple[Union[EventAction, EventCondition]]: + return self._events + + @events.setter + def events(self, events) -> None: + self._events = tuple(events) + + def are_triggering(self, state, action): + # status_i, status_t = [], [] + status = [] + for ev in self.events: + if isinstance(ev, EventCondition) or isinstance(ev, EventAction): + status.append(ev.is_triggering(state, [action])) + # status_i.append(ev.is_triggering(state, action)) + continue + status.append(ev.are_triggering(state, action)) + # status_t.append(ev.are_triggering(state, action)) + + return any(status) + # status = [] + # for ev in self.events: + # if isinstance(ev, EventCondition) or isinstance(ev, EventAction): + # status.append(ev.is_triggering(state, action)) + # status + # return any(status) + + def are_events_triggered(self, state, action): + return any((ev.is_triggering(state, action) for ev in self.events)) def __hash__(self) -> int: - return hash((self.actions, self.verb_tense, self.verb_tense_precond, self.verb_tense_postcond, self.traceable)) + return hash(self.events) def __eq__(self, other: Any) -> bool: - return (isinstance(other, EventAction) and - self.actions == other.actions and - self.verb_tense == other.verb_tense and - self.verb_tense_precond == other.verb_tense_precond and - self.verb_tense_postcond == other.verb_tense_postcond and - self.traceable == other.traceable) + return (isinstance(other, EventOr) + and self.events == other.events) - def copy(self) -> "EventAction": - """ Copy this event. """ + def serialize(self) -> Mapping: + """ Serialize this EventOr. + + Results: + EventOr's data serialized to be JSON compatible + """ + return {"events_EventOr": [ev.serialize() for ev in self.events]} + + @classmethod + def deserialize(cls, data: Mapping) -> "EventOr": + """ Creates a `EventOr` from serialized data. + + Args: + data: Serialized data with the needed information to build a `EventOr` object. + """ + events = [] + for d in data["events_EventOr"]: + if "condition_EventCondition" in d.keys(): + events.append(EventCondition.deserialize(d)) + elif "actions_EventAction" in d.keys(): + events.append(EventAction.deserialize(d)) + elif "actions_Event" in d.keys(): + events.append(Event.deserialize(d)) + elif "events_EventAnd" in d.keys(): + events.append(EventAnd.deserialize(d)) + elif "events_EventOr" in d.keys(): + events.append(EventOr.deserialize(d)) + + return cls(events) + + def copy(self) -> "EventOr": + """ Copy this EventOr. """ + return self.deserialize(self.serialize()) + + +class EventAnd: + def __init__(self, events=()): + self.events = events + self._all_triggered = False + self._all_untriggered = False + + @property + def events(self) -> Tuple[Union[EventAction, EventCondition]]: + return self._events + + @events.setter + def events(self, events) -> None: + self._events = tuple(events) + + def are_triggering(self, state, action): + # status_i, status_t = [], [] + status = [] + for ev in self.events: + if isinstance(ev, EventCondition) or isinstance(ev, EventAction): + status.append(ev.is_triggering(state, [action])) + # status_i.append(ev.is_triggering(state, action)) + continue + status.append(ev.are_triggering(state, action)) + # status_t.append(ev.are_triggering(state, action)) + # status_i + # status_t + return all(status) + + def are_events_triggered(self, state, action): + return all((ev.is_triggering(state, action) for ev in self.events)) + + def __hash__(self) -> int: + return hash(self.events) + + def __eq__(self, other: Any) -> bool: + return (isinstance(other, EventAnd) + and self.events == other.events) + + def serialize(self) -> Mapping: + """ Serialize this EventAnd. + + Results: + EventAnd's data serialized to be JSON compatible + """ + return {"events_EventAnd": [ev.serialize() for ev in self.events]} + + @classmethod + def deserialize(cls, data: Mapping) -> "EventAnd": + """ Creates a `EventAnd` from serialized data. + + Args: + data: Serialized data with the needed information to build a `EventAnd` object. + """ + events = [] + for d in data["events_EventAnd"]: + if "condition_EventCondition" in d.keys(): + events.append(EventCondition.deserialize(d)) + elif "actions_EventAction" in d.keys(): + events.append(EventAction.deserialize(d)) + elif "actions_Event" in d.keys(): + events.append(Event.deserialize(d)) + elif "events_EventAnd" in d.keys(): + events.append(EventAnd.deserialize(d)) + elif "events_EventOr" in d.keys(): + events.append(EventOr.deserialize(d)) + + return cls(events) + + def copy(self) -> "EventAnd": + """ Copy this EventAnd. """ return self.deserialize(self.serialize()) @@ -353,8 +557,8 @@ class Quest: """ def __init__(self, - win_events: Iterable[Union[EventCondition, EventAction]] = (), - fail_events: Iterable[Union[EventCondition, EventAction]] = (), + win_events: Iterable[Union[EventAnd, EventOr]] = (), + fail_events: Iterable[Union[EventAnd, EventOr]] = (), reward: Optional[int] = None, desc: Optional[str] = None, commands: Iterable[str] = ()) -> None: @@ -372,11 +576,14 @@ def __init__(self, desc: A text description of the quest. commands: List of text commands leading to this quest completion. """ - self.win_events = tuple(win_events) - self.fail_events = tuple(fail_events) + self.win_events = win_events + self.fail_events = fail_events self.desc = desc self.commands = tuple(commands) + self.win_events_list = self.events_organizer(self.win_events) + self.fail_events_list = self.events_organizer(self.fail_events) + # Unless explicitly provided, reward is set to 1 if there is at least # one winning events otherwise it is set to 0. self.reward = int(len(win_events) > 0) if reward is None else reward @@ -385,21 +592,37 @@ def __init__(self, raise UnderspecifiedQuestError() @property - def win_events(self) -> Iterable[EventCondition]: + def win_events(self) -> Iterable[Union[EventOr, EventAnd]]: return self._win_events @win_events.setter - def win_events(self, events: Iterable[EventCondition]) -> None: + def win_events(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: self._win_events = tuple(events) @property - def fail_events(self) -> Iterable[EventCondition]: + def win_events_list(self) -> Iterable[Union[EventOr, EventAnd]]: + return self._win_events_list + + @win_events_list.setter + def win_events_list(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: + self._win_events_list = tuple(events) + + @property + def fail_events(self) -> Iterable[Union[EventOr, EventAnd]]: return self._fail_events @fail_events.setter - def fail_events(self, events: Iterable[EventCondition]) -> None: + def fail_events(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: self._fail_events = tuple(events) + @property + def fail_events_list(self) -> Iterable[Union[EventOr, EventAnd]]: + return self._fail_events_list + + @fail_events_list.setter + def fail_events_list(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: + self._fail_events_list = tuple(events) + @property def commands(self) -> Iterable[str]: return self._commands @@ -408,17 +631,39 @@ def commands(self) -> Iterable[str]: def commands(self, commands: Iterable[str]) -> None: self._commands = tuple(commands) - def is_winning(self, state: State) -> bool: + def event_organizer(self, combined_event=(), _events=[]): + if isinstance(combined_event, EventCondition) or isinstance(combined_event, EventAction): + _events.append(combined_event) + return + + act = [] + for event in combined_event.events: + out = self.event_organizer(event, act) + if out: + for a in out: + _events.append(a) + + return (len(act) > 0 and len(act) > len(_events)) * act or (len(_events) > 0 and len(_events) > len(act)) * _events + + def events_organizer(self, combined_events=()): + _events_ = [] + for comb_ev in combined_events: + for ev in self.event_organizer(comb_ev, _events=[]): + _events_.append(ev) + + return _events_ + + def is_winning(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this quest is winning in that particular state. """ - return any(event.is_triggering(state) for event in self.win_events) - def is_failing(self, state: State) -> bool: + return any(event.is_triggering(state, actions) for event in self.win_events) + + def is_failing(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this quest is failing in that particular state. """ - return any(event.is_triggering(state) for event in self.fail_events) + return any(event.is_triggering(state, actions) for event in self.fail_events) def __hash__(self) -> int: - return hash((self.win_events, self.fail_events, self.reward, - self.desc, self.commands)) + return hash((self.win_events, self.fail_events, self.reward, self.desc, self.commands)) def __eq__(self, other: Any) -> bool: return (isinstance(other, Quest) @@ -438,19 +683,17 @@ def deserialize(cls, data: Mapping) -> "Quest": """ win_events = [] for d in data["win_events"]: - if "output_verb_tense_precond" in d.keys(): - win_events.append(EventAction.deserialize(d)) - - if "condition" in d.keys(): - win_events.append(EventCondition.deserialize(d)) + if "events_EventOr" in d.keys(): + win_events.append(EventOr.deserialize(d)) + elif "events_EventAnd" in d.keys(): + win_events.append(EventAnd.deserialize(d)) fail_events = [] for d in data["fail_events"]: - if "output_verb_tense_precond" in d.keys(): - fail_events.append(EventAction.deserialize(d)) - - if "condition" in d.keys(): - fail_events.append(EventCondition.deserialize(d)) + if "events_EventOr" in d.keys(): + fail_events.append(EventOr.deserialize(d)) + elif "events_EventAnd" in d.keys(): + fail_events.append(EventAnd.deserialize(d)) commands = data.get("commands", []) reward = data["reward"] @@ -463,18 +706,166 @@ def serialize(self) -> Mapping: Results: Quest's data serialized to be JSON compatible """ - data = {} - data["desc"] = self.desc - data["reward"] = self.reward - data["commands"] = self.commands - data["win_events"] = [event.serialize() for event in self.win_events] - data["fail_events"] = [event.serialize() for event in self.fail_events] - return data + return { + "desc": self.desc, + "reward": self.reward, + "commands": self.commands, + "win_events": [event.serialize() for event in self.win_events], + "fail_events": [event.serialize() for event in self.fail_events] + } def copy(self) -> "Quest": """ Copy this quest. """ return self.deserialize(self.serialize()) +# class Quest: +# """ Quest representation in TextWorld. +# +# A quest is defined by a mutually exclusive set of winning events and +# a mutually exclusive set of failing events. +# +# Attributes: +# win_events: Mutually exclusive set of winning events. That is, +# only one such event needs to be triggered in order +# to complete this quest. +# fail_events: Mutually exclusive set of failing events. That is, +# only one such event needs to be triggered in order +# to fail this quest. +# reward: Reward given for completing this quest. +# desc: A text description of the quest. +# commands: List of text commands leading to this quest completion. +# """ +# +# def __init__(self, +# win_events: Iterable[Union[Event, EventCondition, EventAction]] = (), +# fail_events: Iterable[Union[Event, EventCondition, EventAction]] = (), +# reward: Optional[int] = None, +# desc: Optional[str] = None, +# commands: Iterable[str] = ()) -> None: +# r""" +# Args: +# win_events: Mutually exclusive set of winning events. That is, +# only one such event needs to be triggered in order +# to complete this quest. +# fail_events: Mutually exclusive set of failing events. That is, +# only one such event needs to be triggered in order +# to fail this quest. +# reward: Reward given for completing this quest. By default, +# reward is set to 1 if there is at least one winning events +# otherwise it is set to 0. +# desc: A text description of the quest. +# commands: List of text commands leading to this quest completion. +# """ +# self.win_events = tuple(win_events) +# self.fail_events = tuple(fail_events) +# self.desc = desc +# self.commands = tuple(commands) +# +# # Unless explicitly provided, reward is set to 1 if there is at least +# # one winning events otherwise it is set to 0. +# self.reward = int(len(win_events) > 0) if reward is None else reward +# +# if len(self.win_events) == 0 and len(self.fail_events) == 0: +# raise UnderspecifiedQuestError() +# +# @property +# def win_events(self) -> Iterable[Union[Event, EventCondition, EventAction]]: +# return self._win_events +# +# @win_events.setter +# def win_events(self, events: Iterable[Union[Event, EventCondition, EventAction]]) -> None: +# self._win_events = tuple(events) +# +# @property +# def fail_events(self) -> Iterable[Union[Event, EventCondition, EventAction]]: +# return self._fail_events +# +# @fail_events.setter +# def fail_events(self, events: Iterable[Union[Event, EventCondition, EventAction]]) -> None: +# self._fail_events = tuple(events) +# +# @property +# def commands(self) -> Iterable[str]: +# return self._commands +# +# @commands.setter +# def commands(self, commands: Iterable[str]) -> None: +# self._commands = tuple(commands) +# +# def is_winning(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: +# """ Check if this quest is winning in that particular state. """ +# +# return any(event.is_triggering(state, actions) for event in self.win_events) +# +# def is_failing(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: +# """ Check if this quest is failing in that particular state. """ +# return any(event.is_triggering(state, actions) for event in self.fail_events) +# +# def __hash__(self) -> int: +# return hash((self.win_events, self.fail_events, self.reward, self.desc, self.commands)) +# +# def __eq__(self, other: Any) -> bool: +# return (isinstance(other, Quest) +# and self.win_events == other.win_events +# and self.fail_events == other.fail_events +# and self.reward == other.reward +# and self.desc == other.desc +# and self.commands == other.commands) +# +# @classmethod +# def deserialize(cls, data: Mapping) -> "Quest": +# """ Creates a `Quest` from serialized data. +# +# Args: +# data: Serialized data with the needed information to build a +# `Quest` object. +# """ +# +# win_events = [] +# for d in data["win_events"]: +# if "action_verb_tense" in d.keys(): +# win_events.append(Event.deserialize(d)) +# +# elif "output_verb_tense" in d.keys() and "commands" in d.keys(): +# win_events.append(EventCondition.deserialize(d)) +# +# else: +# win_events.append(EventAction.deserialize(d)) +# +# fail_events = [] +# for d in data["fail_events"]: +# if "action_verb_tense" in d.keys(): +# fail_events.append(Event.deserialize(d)) +# +# elif "output_verb_tense" in d.keys() and "commands" in d.keys(): +# fail_events.append(EventCondition.deserialize(d)) +# +# else: +# fail_events.append(EventAction.deserialize(d)) +# +# commands = data.get("commands", []) +# reward = data["reward"] +# desc = data["desc"] +# return cls(win_events, fail_events, reward, desc, commands) +# +# def serialize(self) -> Mapping: +# """ Serialize this quest. +# +# Results: +# Quest's data serialized to be JSON compatible +# """ +# data = {} +# data["desc"] = self.desc +# data["reward"] = self.reward +# data["commands"] = self.commands +# data["win_events"] = [event.serialize() for event in self.win_events] +# data["fail_events"] = [event.serialize() for event in self.fail_events] +# return data +# +# def copy(self) -> "Quest": +# """ Copy this quest. """ +# return self.deserialize(self.serialize()) + class EntityInfo: """ Additional information about entities in the game. """ @@ -558,11 +949,14 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None, self.quests = tuple(quests) self.metadata = {} self._objective = None + # self.objective self._infos = self._build_infos() self.kb = world.kb self.change_grammar(grammar) + + @property def infos(self) -> Dict[str, EntityInfo]: """ Information about the entities in the game. """ @@ -595,15 +989,17 @@ def change_grammar(self, grammar: Grammar) -> None: inform7 = Inform7Game(self) _gen_commands = inform7.gen_commands_from_actions generate_text_from_grammar(self, self.grammar) + from textworld.generator.text_generation import describe_quests + self.objective = describe_quests(self, self.grammar) for quest in self.quests: # TODO: should have a generic way of generating text commands from actions # instead of relying on inform7 convention. - for event in quest.win_events: + for event in quest.win_events_list: event.commands = _gen_commands(event.actions) - if quest.win_events: - quest.commands = quest.win_events[0].commands + if quest.win_events_list: + quest.commands = quest.win_events_list[0].commands # Check if we can derive a global winning policy from the quests. if self.grammar: @@ -615,6 +1011,26 @@ def change_grammar(self, grammar: Grammar) -> None: self.metadata["walkthrough"] = commands self.objective = describe_event(EventCondition(policy), self, self.grammar) + def command_generator(self, events, _gen_commands): + for event in events: + events.commands = _gen_commands(events.actions) + + # def command_generator(self, events, _gen_commands, quest): + # if isinstance(events, EventCondition) or isinstance(events, EventAction): + # events.commands = _gen_commands(events.actions) + # # quest.append(events.actions) + # quest.append(events) + # return + # + # act = [] + # for event in events.events: + # out = self.command_generator(event, _gen_commands, act) + # if out: + # for a in out: + # quest.append(a) + # + # return (len(act) > 0 and len(act) > len(quest)) * act or (len(quest) > 0 and len(quest) > len(act)) * quest + def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ with open(filename, 'w') as f: @@ -743,7 +1159,7 @@ def objective(self) -> str: return self._objective # TODO: Find a better way of describing the objective of the game with several quests. - self._objective = "\nAND\n".join(quest.desc for quest in self.quests if quest.desc) + self._objective = "\n The next quest is \n".join(quest.desc for quest in self.quests if quest.desc) return self._objective @@ -884,7 +1300,7 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event: Union[EventCondition, EventAction], kb: KnowledgeBase) -> None: + def __init__(self, event: Union[Event, EventCondition, EventAction], kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. @@ -899,14 +1315,30 @@ def __init__(self, event: Union[EventCondition, EventAction], kb: KnowledgeBase) self._tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) - if isinstance(event, EventCondition): - if len(event.actions) > 0: - self._tree.push(event.condition) - - for action in event.actions[::-1]: - self._tree.push(action) + self.tree_policy(event) + # if not isinstance(event, EventAction) and not isinstance(event, Event): + # if len(event.actions) > 0: + # self._tree.push(event.condition) + # + # for action in event.actions[::-1]: + # self._tree.push(action) + # + # self._policy = event.actions + (event.condition,) + + def tree_policy(self, event): + if isinstance(event, EventCondition) or isinstance(event, EventAction): + if isinstance(event, EventCondition): + if len(event.actions) > 0: + self._tree.push(event.condition) + + for action in event.actions[::-1]: + self._tree.push(action) + + self._policy = event.actions + (event.condition,) + return - self._policy = event.actions + (event.condition,) + for ev in event.events: + self.tree_policy(ev) def copy(self) -> "EventProgression": """ Return a soft copy. """ @@ -941,7 +1373,7 @@ def untriggerable(self) -> bool: """ Check whether the event is in an untriggerable state. """ return self._untriggerable - def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + def update(self, action: Tuple[Action] = (), state: Optional[State] = None) -> None: """ Update event progression given available information. Args: @@ -953,18 +1385,13 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if state is not None: # Check if event is triggered. - - if isinstance(self.event, EventCondition): - self._triggered = self.event.is_triggering(state) - - if isinstance(self.event, EventAction): - self._triggered = self.event.is_triggering(action) + self._triggered = self.event.are_triggering(state, action) # Try compressing the winning policy given the new game state. if self.compress_policy(state): return # A shorter winning policy has been found. - if action is not None and not self._tree.empty: + if action and not self._tree.empty: # Determine if we moved away from the goal or closer to it. changed, reverse_action = self._tree.remove(action) if changed and reverse_action is None: # Irreversible action. @@ -1007,12 +1434,8 @@ def _find_shorter_policy(policy): return compressed - def will_trigger(self, state: State, action: Action): - if isinstance(self.event, EventCondition): - triggered = self.event.is_triggering(state) - - if isinstance(self.event, EventAction): - triggered = self.event.is_triggering(action) + def will_trigger(self, state: State, action: Tuple[Action]): + triggered = self.event.are_triggering(state, action) return triggered @@ -1075,7 +1498,8 @@ def done(self) -> bool: @property def completed(self) -> bool: """ Check whether the quest is completed. """ - return any(event.triggered for event in self.win_events) + return all(event.triggered for event in self.win_events) + # return any(event.triggered for event in self.win_events) @property def failed(self) -> bool: @@ -1116,12 +1540,13 @@ def __init__(self, game: Game, track_quests: bool = True) -> None: """ self.game = game self.state = game.world.state.copy() - self._valid_actions = self.valid_actions_gen() + self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) self.quest_progressions = [] if track_quests: self.quest_progressions = [QuestProgression(quest, game.kb) for quest in game.quests] for quest_progression in self.quest_progressions: - quest_progression.update(action=None, state=self.state) + quest_progression.update(action=(), state=self.state) def copy(self) -> "GameProgression": """ Return a soft copy. """ @@ -1201,14 +1626,27 @@ def winning_policy(self) -> Optional[List[Action]]: # Discard all "trigger" actions. return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") + def any_traceable_exist(self, events): + if isinstance(events, EventCondition) or isinstance(events, EventAction): + return len(events.traceable) > 0 and not (events.traceable in self.state.facts) + + trc_exist = [] + for event in events.events: + trc_exist.append(self.any_traceable_exist(event)) + + return any(trc_exist) + def add_traceables(self, action): s = self.state.facts + trace = [] for quest_progression in self.quest_progressions: - if not quest_progression.completed and (quest_progression.quest.reward >= 0): + if quest_progression.quest.reward >= 0: for win_event in quest_progression.win_events: - if win_event.event.traceable and not (win_event.event.traceable in s): - if win_event.will_trigger(self.state, action): - self.state.add_facts(PropositionControl.add_propositions(win_event.event.traceable)) + if self.any_traceable_exist(win_event.event): + if win_event.will_trigger(self.state, tuple([action])): + trace.append(tr for eve in win_event.event.events for tr in eve.traceable) + + return [p for ar in trace for p in ar] def traceable_manager(self): if not self.state.has_traceable(): @@ -1216,7 +1654,6 @@ def traceable_manager(self): for prop in self.state.get_facts(): if not prop.name.startswith('is__'): - PropositionControl.set_activated(prop) PropositionControl.remove(prop, self.state) def update(self, action: Action) -> None: @@ -1227,17 +1664,28 @@ def update(self, action: Action) -> None: """ # Update world facts self.state.apply(action) - self.add_traceables(action) + trace = self.add_traceables(action) + if trace: + for prop in trace: + if prop.name.startswith('has_been') and prop not in self.state.facts: + self.state.add_facts([prop]) # Update all quest progressions given the last action and new state. for quest_progression in self.quest_progressions: quest_progression.update(action, self.state) # Update world facts. + if trace: + for prop in trace: + if not prop.name.startswith('has_been') and prop not in self.state.facts: + self.state.add_facts([prop]) + self.traceable_manager() # Get valid actions. - self._valid_actions = self.valid_actions_gen() + self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) + x = 0 class GameOptions: diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index 4877f43b..387afad8 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -15,7 +15,7 @@ from textworld.utils import make_temp_directory, str2bool, chunk -from textworld.generator.game import Event, EventAction, Game +from textworld.generator.game import EventCondition, EventAction, Event, EventAnd, EventOr, Game from textworld.generator.world import WorldRoom, WorldEntity from textworld.logic import Signature, Proposition, Action, Variable @@ -99,7 +99,12 @@ def gen_source_for_attribute(self, attr: Proposition) -> Optional[str]: def gen_source_for_attributes(self, attributes: Iterable[Proposition]) -> str: source = "" for attr in attributes: - source_attr = self.gen_source_for_attribute(attr) + if attr.name.count('__') == 0: + attr_ = Proposition(name='is__' + attr.name, arguments=attr.arguments, verb='is', definition=attr.name) + else: + attr_ = attr + + source_attr = self.gen_source_for_attribute(attr_) if source_attr: source += source_attr + ".\n" @@ -213,6 +218,9 @@ def gen_source_for_rooms(self) -> str: def _get_name_mapping(self, action): mapping = self.kb.rules[action.name].match(action) + for ph, var in mapping.items(): + a = ph.name + b = self.entity_infos[var.name].name return {ph.name: self.entity_infos[var.name].name for ph, var in mapping.items()} def _get_entities_mapping(self, action): @@ -339,6 +347,8 @@ def gen_source(self, seed: int = 1234) -> str: objective = self.game.objective.replace("\n", "[line break]") maximum_score = 0 wining = 0 + quests_text, viewed_actions = [], {} + action_id = [] for quest_id, quest in enumerate(self.game.quests): maximum_score += quest.reward @@ -346,14 +356,14 @@ def gen_source(self, seed: int = 1234) -> str: The quest{quest_id} completed is a truth state that varies. The quest{quest_id} completed is usually false. """) - source += quest_completed.format(quest_id=quest_id) + quest_ending = quest_completed.format(quest_id=quest_id) - for event_id, event in enumerate(quest.win_events): + for event_id, event in enumerate(quest.win_events_list): commands = self.gen_commands_from_actions(event.actions) event.commands = commands walkthrough = '\nTest quest{}_{} with "{}"\n\n'.format(quest_id, event_id, " / ".join(commands)) - source += walkthrough + quest_ending += walkthrough # Add winning and losing conditions for quest. quest_ending_conditions = textwrap.dedent("""\ @@ -361,54 +371,71 @@ def gen_source(self, seed: int = 1234) -> str: do nothing;""".format(quest_id=quest_id)) fail_template = textwrap.dedent(""" - else if {conditions}: - end the story; [Lost]""") + otherwise if {conditions}: + end the story; [Lost];""") - win_template_state = textwrap.dedent(""" - else if {conditions}: + win_template = textwrap.dedent(""" + otherwise if {conditions}: increase the score by {reward}; [Quest completed] - Now the quest{quest_id} completed is true;""") + Now the quest{quest_id} completed is true; + {removed_conditions}""") - win_template_action = textwrap.dedent(""" - else: - After {conditions}: - increase the score by {reward}; [Quest completed] - Now the quest{quest_id} completed is true;""") + otherwise_template = textwrap.dedent("""\ + otherwise: + {removed_conditions}""") + conditions, removals = '', '' + cond_id = [] for fail_event in quest.fail_events: - if isinstance(fail_event, Event): - param = fail_event.condition - if isinstance(fail_event, EventAction): - param = [act for act in fail_event.actions][0] + condition, removed_conditions, final_condition, _, _ = self.get_events(fail_event, + textwrap.dedent(""""""), + textwrap.dedent(""""""), + action_id=action_id, + cond_id=cond_id, + quest_id=quest_id, + rwd_conds=viewed_actions) + removals += (len(removals) > 0) * ' ' + '' + removed_conditions + quest_ending_conditions += fail_template.format(conditions=final_condition) + conditions += condition - conditions = self.gen_source_for_conditions(param.preconditions) - quest_ending_conditions += fail_template.format(conditions=conditions) + wining += 1 for win_event in quest.win_events: - if isinstance(win_event, Event): - conditions = self.gen_source_for_conditions(win_event.condition.preconditions) - quest_ending_conditions += win_template_state.format(conditions=conditions, - reward=quest.reward, - quest_id=quest_id) - - if isinstance(win_event, EventAction): - conditions = self.gen_source_for_actions([act for act in win_event.actions]) - quest_ending_conditions += win_template_action.format(conditions=conditions, - reward=quest.reward, - quest_id=quest_id) + condition, removed_conditions, final_condition, _, _ = self.get_events(win_event, + textwrap.dedent(""""""), + textwrap.dedent(""""""), + action_id=action_id, + cond_id=cond_id, + quest_id=quest_id, + rwd_conds=viewed_actions) + removals += (len(removals) > 0) * ' ' + '' + removed_conditions + quest_ending_conditions += win_template.format(reward=quest.reward, quest_id=quest_id, + conditions=final_condition, + removed_conditions=textwrap.indent(removals, "")) + conditions += condition + wining += 1 - quest_ending = """\ + if removals: + quest_ending_conditions += otherwise_template.format(removed_conditions=textwrap.indent(removals, "")) + + quest_condition_template = """\ Every turn:\n{conditions} """.format(conditions=textwrap.indent(quest_ending_conditions, " ")) - source += textwrap.dedent(quest_ending) + + quest_ending += textwrap.dedent(quest_condition_template) + + source += textwrap.dedent(conditions) + source += textwrap.dedent('\n') + quests_text += [quest_ending] + + source += textwrap.dedent('\n'.join(txt for txt in quests_text if txt)) # Enable scoring is at least one quest has nonzero reward. if maximum_score >= 0: source += "Use scoring. The maximum score is {}.\n".format(maximum_score) - # Build test condition for winning the game. game_winning_test = "1 is 0 [always false]" if wining > 0: @@ -1017,6 +1044,175 @@ def gen_source(self, seed: int = 1234) -> str: return source + def get_events(self, combined_events, txt, rmv, quest_id, rwd_conds, action_id=[], + cond_id=[], check_vars=[]): + + action_processing_template = textwrap.dedent(""" + The action{action_id} check is a truth state that varies. + The action{action_id} check is usually false. + After {actions}: + Now the action{action_id} check is true. + """) + + remove_action_processing_template = textwrap.dedent("""Now the action{action_id} check is false; + """) + + combined_ac_processing_template = textwrap.dedent(""" + The condition{cond_id} of quest{quest_id} check is a truth state that varies. + The condition{cond_id} of quest{quest_id} check is usually false. + Every turn: + if {conditions}: + Now the condition{cond_id} of quest{quest_id} check is true. + """) + + remove_condition_processing_template = textwrap.dedent("""Now the condition{cond_id} of quest{quest_id} check is false; + """) + + if isinstance(combined_events, EventCondition) or isinstance(combined_events, EventAction): + if isinstance(combined_events, EventCondition): + check_vars += [self.gen_source_for_conditions(combined_events.condition.preconditions)] + return [None] * 5 + + elif isinstance(combined_events, EventAction): + i7_ = self.gen_source_for_actions(combined_events.actions) + if not rwd_conds or i7_ not in rwd_conds.values(): + txt += [action_processing_template.format(action_id=len(action_id), actions=i7_)] + rmv += [remove_action_processing_template.format(action_id=len(action_id))] + temp = [self.gen_source_for_conditions([prop]) for prop in combined_events.actions[0].preconditions + if prop.verb != 'is'] + if temp: + temp = ' and ' + ' and '.join(t for t in temp) + else: + temp = '' + check_vars += ['action{action_id} check is true'.format(action_id=len(action_id)) + temp] + rwd_conds['action{action_id}'.format(action_id=len(action_id))] = i7_ + action_id += [1] + else: + word = list(rwd_conds.keys())[list(rwd_conds.values()).index(i7_)] + rmv += [remove_action_processing_template.format(action_id=word[6:])] + temp = [self.gen_source_for_conditions([prop]) for prop in combined_events.actions[0].preconditions + if prop.verb != 'is'] + check_vars += ['action{action_id} check is true'.format(action_id=word[6:]) + ' and ' + + ' and '.join(t for t in temp)] + + return [None] * 5 + + act_type, _txt, _rmv, _check_vars, _cond_id = [], [], [], [], [] + for event in combined_events.events: + st, rm, a3, a4, cond_type = self.get_events(event, _txt, _rmv, quest_id, rwd_conds, action_id, cond_id, + check_vars=_check_vars) + act_type.append(isinstance(event, EventAction)) + + if st: + _txt += [st] + _rmv += [rm] + _check_vars.append('condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id)-1, + quest_id=quest_id)) + if cond_type: + _cond_id += cond_type + + if any(_cond_id): + _rmv += [remove_condition_processing_template.format(quest_id=quest_id, cond_id=len(cond_id) - 1)] + + event_rule = isinstance(combined_events, EventAnd) * ' and ' + isinstance(combined_events, EventOr) * ' or ' + condition_ = event_rule.join(cv for cv in _check_vars) + tp_txt = ''.join(tx for tx in _txt) + tp_txt += combined_ac_processing_template.format(quest_id=quest_id, cond_id=len(cond_id), conditions=condition_) + tp_rmv = ' '.join(ac for ac in _rmv if ac) + fin_cond = 'condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id), quest_id=quest_id) + cond_id += [1] + if any(act_type): + cond_type = [True] + + return tp_txt, tp_rmv, fin_cond, [action_id, cond_id, rwd_conds], cond_type + + + + # def get_events(self, combined_events, txt, rmv, quest_id, action_id=[], state_id=[], cond_id=[], check_vars=[], rwd_conds=[]): + # + # remove_action_processing_template = textwrap.dedent("""if the action{action_id} check is true: + # Now the action{action_id} check is false. + # """) + # remove_condition_processing_template = textwrap.dedent("""if the condition{cond_id} of quest{quest_id} check is true: + # Now the condition{cond_id} of quest{quest_id} check is false. + # """) + # action_processing_template = textwrap.dedent(""" + # The action{action_id} check is a truth state that varies. + # The action{action_id} check is usually false. + # After {actions}: + # Now the action{action_id} check is true. + # """) + # state_processing_template = textwrap.dedent(""" + # The state{state_id} of quest{quest_id} check is a truth state that varies. + # The state{state_id} of quest{quest_id} check is usually false. + # Every turn: + # if {conditions}: + # Now the state{state_id} of quest{quest_id} check is true. + # """) + # combined_ac_processing_template = textwrap.dedent(""" + # The condition{cond_id} of quest{quest_id} check is a truth state that varies. + # The condition{cond_id} of quest{quest_id} check is usually false. + # Every turn: + # if {conditions}: + # Now the condition{cond_id} of quest{quest_id} check is true. + # """) + # + # if isinstance(combined_events, EventCondition) or isinstance(combined_events, EventAction): + # if isinstance(combined_events, EventCondition): + # i7_ = self.gen_source_for_conditions(combined_events.condition.preconditions) + # # txt += [i7_] + # check_vars.append(i7_) + # state_id += [1] + # return [None] * 6 + # + # elif isinstance(combined_events, EventAction): + # if not rwd_conds or 'action{action_id}'.format(action_id=len(action_id)) not in rwd_conds: + # i7_ = self.gen_source_for_actions(combined_events.actions) + # txt += [action_processing_template.format(quest_id=quest_id, action_id=len(action_id), actions=i7_)] + # rmv += [remove_action_processing_template.format(quest_id=quest_id, action_id=len(action_id))] + # check_vars.append('action{action_id} check is true'.format(action_id=len(action_id))) + # rwd_conds += {i7_: 'action{action_id}'.format(action_id=len(action_id))} + # action_id += [1] + # else: + # rmv += [remove_action_processing_template.format(quest_id=quest_id, action_id=len(action_id))] + # check_vars.append('action{action_id} check is true'.format(action_id=len(action_id))) + # + # return [None] * 6 + # + # # act_type = [] + # act_type, _txt, _rmv, _check_vars, _cond_id, _rwd_conds = [], [], [], [], [], [] + # # _cond_id, _rwd_conds = [], [] + # for event in combined_events.events: + # st, rm, _, _, cond_type, rwd = self.get_events(event, _txt, _rmv, quest_id, action_id, state_id, cond_id, + # _check_vars, _rwd_conds) + # act_type.append(isinstance(event, EventAction)) + # + # if st: + # _txt += [st] + # _rmv += [rm] + # _check_vars.append('condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id)-1, + # quest_id=quest_id)) + # # _rwd_conds += rwd + # + # if cond_type: + # _cond_id += cond_type + # + # if any(_cond_id): + # _rmv += [remove_condition_processing_template.format(quest_id=quest_id, cond_id=len(cond_id) - 1)] + # + # event_rule = isinstance(combined_events, EventAnd) * ' and ' + isinstance(combined_events, EventOr) * ' or ' + # # _rwd_conds += _check_vars + # condition_ = event_rule.join(cv for cv in _check_vars) + # tp_txt = ''.join(tx for tx in _txt) + # tp_txt += combined_ac_processing_template.format(quest_id=quest_id, cond_id=len(cond_id), conditions=condition_) + # tp_rmv = ' '.join(ac for ac in _rmv if ac) + # fin_cond = 'condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id), quest_id=quest_id) + # cond_id += [1] + # if any(act_type): + # cond_type = [True] + # + # return tp_txt, tp_rmv, fin_cond, [action_id, state_id, cond_id, _rwd_conds], cond_type, _rwd_conds + def generate_inform7_source(game: Game, seed: int = 1234, use_i7_description: bool = False) -> str: inform7 = Inform7Game(game) diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 9dcfd0d7..68f2c83e 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -21,7 +21,7 @@ from textworld.generator.vtypes import get_new from textworld.logic import State, Variable, Proposition, Action from textworld.generator.game import GameOptions -from textworld.generator.game import Game, World, Quest, Event, EntityInfo +from textworld.generator.game import Game, World, Quest, EventAnd, EventOr, EventCondition, EventAction, EntityInfo from textworld.generator.graph_networks import DIRECTIONS from textworld.render import visualize from textworld.envs.wrappers import Recorder @@ -29,7 +29,7 @@ def get_failing_constraints(state, kb: Optional[KnowledgeBase] = None): kb = kb or KnowledgeBase.default() - fail = Proposition("fail", []) + fail = Proposition("is__fail", []) failed_constraints = [] constraints = state.all_applicable_actions(kb.constraints.values()) @@ -76,6 +76,12 @@ def __init__(self, failed_constraints: List[Action]) -> None: super().__init__(msg) +class UnderspecifiedEventError(NameError): + def __init__(self): + msg = "Either the actions or the conditions is required to create an event. Both cannot be provided." + super().__init__(msg) + + class WorldEntity: """ Represents an entity in the world. @@ -145,7 +151,7 @@ def add_fact(self, name: str, *entities: List["WorldEntity"]) -> None: *entities: A list of entities as arguments to the new fact. """ args = [entity.var for entity in entities] - self._facts.append(Proposition(name, args)) + self._facts.append(Proposition(name='is__' + name, arguments=args)) def remove_fact(self, name: str, *entities: List["WorldEntity"]) -> None: args = [entity.var for entity in entities] @@ -632,7 +638,7 @@ def record_quest(self) -> Quest: actions = [action for action in recorder.actions if action is not None] # Assume the last action contains all the relevant facts about the winning condition. - event = Event(actions=actions) + event = EventCondition(actions=actions) self.quests.append(Quest(win_events=[event])) # Calling build will generate the description for the quest. self.build() @@ -665,14 +671,14 @@ def set_quest_from_commands(self, commands: List[str]) -> Quest: unrecognized_commands = [c for c, a in zip(commands, recorder.actions) if a is None] raise QuestError("Some of the actions were unrecognized: {}".format(unrecognized_commands)) - event = Event(actions=actions) + event = EventCondition(actions=actions, conditions=winning_facts) self.quests = [Quest(win_events=[event])] # Calling build will generate the description for the quest. self.build() return self.quests[-1] - def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + def new_fact(self, name: str, *entities: List["WorldEntity"]) -> Proposition: """ Create new fact. Args: @@ -682,7 +688,7 @@ def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: args = [entity.var for entity in entities] return Proposition(name, args) - def new_rule_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + def new_action(self, name: str, *entities: List["WorldEntity"]) -> Union[None, Action]: """ Create new fact about a rule. Args: @@ -694,8 +700,7 @@ def new_conditions(conditions, args): new_ph = [] for pred in conditions: new_var = [var for ph in pred.parameters for var in args if ph.type == var.type] - new_ph.append(Proposition(pred.name, new_var)) - + new_ph.append(Proposition(name=pred.name, arguments=new_var)) return new_ph args = [entity.var for entity in entities] @@ -705,11 +710,77 @@ def new_conditions(conditions, args): precond = new_conditions(rule.preconditions, args) postcond = new_conditions(rule.postconditions, args) - return Action(rule.name, precond, postcond) + action = Action(rule.name, precond, postcond) + + if action.has_traceable(): + action.activate_traceable() + + return action return None - def new_event_using_commands(self, commands: List[str]) -> Event: + def new_event(self, action: Iterable[Action] = (), condition: Iterable[Proposition] = (), + command: Iterable[str] = (), condition_verb_tense: dict = (), action_verb_tense: dict = ()): + + if action and condition: + raise UnderspecifiedEventError + + if action: + event = EventAction(actions=action, verb_tense=action_verb_tense, commands=command) + + elif condition: + event = EventCondition(conditions=condition, verb_tense=condition_verb_tense, actions=action, + commands=command) + + # return tuple(ev for ev in [event] if ev) + return event + + def new_operation(self, operation={}): + def func(operator='or', events=[]): + if operator == 'or' and events: + return EventOr(events=events) + if operator == 'and' and events: + return EventAnd(events=events) + else: + raise + + if not isinstance(operation, dict): + if len(operation) == 0: + return () + else: + raise + + y1 = [] + for k, v in operation.items(): + if isinstance(v, dict): + y1.append(self.new_operation(operation=v)[0]) + y1 = [func(k, y1)] + else: + if isinstance(v, EventCondition) or isinstance(v, EventAction): + y1.append(func(k, [v])) + else: + if any((isinstance(it, dict) for it in v)): + y2 = [] + for it in v: + if isinstance(it, dict): + y2.append(self.new_operation(operation=it)[0]) + else: + y2.append(func(k, [it])) + + y1 = [func(k, y2)] + else: + y1.append(func(k, v)) + + return tuple(y1) + + def new_quest(self, win_event=(), fail_event=(), reward=None, desc=None, commands=()) -> Quest: + return Quest(win_events=self.new_operation(operation=win_event), + fail_events=self.new_operation(operation=fail_event), + reward=reward, + desc=desc, + commands=commands) + + def new_event_using_commands(self, commands: List[str]) -> EventCondition: """ Creates a new event using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -731,7 +802,7 @@ def new_event_using_commands(self, commands: List[str]) -> Event: # Skip "None" actions. actions, commands = zip(*[(a, c) for a, c in zip(recorder.actions, commands) if a is not None]) - event = Event(actions=actions, commands=commands) + event = EventCondition(actions=actions, commands=commands) return event def new_quest_using_commands(self, commands: List[str]) -> Quest: diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py index 3b1cdfca..48a91915 100644 --- a/textworld/generator/text_generation.py +++ b/textworld/generator/text_generation.py @@ -4,8 +4,9 @@ import re from collections import OrderedDict +from typing import Union, Iterable -from textworld.generator.game import Quest, Event, Game +from textworld.generator.game import Quest, EventCondition, EventAction, EventAnd, EventOr, Game from textworld.generator.text_grammar import Grammar from textworld.generator.text_grammar import fix_determinant @@ -380,16 +381,305 @@ def generate_instruction(action, grammar, game, counts): return desc, separator +def make_str(txt): + Text = [] + for t in txt: + if len(t) > 0: + Text += [t] + return Text + + +def quest_counter(counter): + if counter == 0: + return '' + elif counter == 1: + return 'First' + elif counter == 2: + return 'Second' + elif counter == 3: + return 'Third' + else: + return str(counter) + 'th' + + +def describe_quests(game: Game, grammar: Grammar): + counter = 1 + quests_desc_arr = [] + for quest in game.quests: + if quest.desc: + quests_desc_arr.append("The " + quest_counter(counter) + " quest: \n" + quest.desc) + counter += 1 + + quests_desc_arr + if quests_desc_arr: + quests_desc_ = " \n ".join(txt for txt in quests_desc_arr if txt) + quests_desc_ = ": \n " + quests_desc_ + " \n *** " + quests_tag = grammar.get_random_expansion("#all_quests#") + quests_tag = quests_tag.replace("(quests_string)", quests_desc_.strip()) + quests_description = grammar.expand(quests_tag) + quests_description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + quests_description) + else: + quests_tag = grammar.get_random_expansion("#all_quests_non#") + quests_description = grammar.expand(quests_tag) + quests_description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + quests_description) + return quests_description + + def assign_description_to_quest(quest: Quest, game: Game, grammar: Grammar): - event_descriptions = [] - for event in quest.win_events: - event_descriptions += [describe_event(event, game, grammar)] + desc = [] + indx = '> ' + for event in quest.win_events[0].events: + if isinstance(event, EventCondition) or isinstance(event, EventAction): + st = assign_description_to_event(event, game, grammar) + else: + st = assign_description_to_combined_events(event, game, grammar, indx) + + if st: + desc += [st] - quest_desc = " OR ".join(desc for desc in event_descriptions if desc) + if quest.reward < 0: + return describe_punishing_quest(make_str(desc), grammar, indx) + else: + return describe_quest(make_str(desc), quest.win_events[0], grammar, indx) + + +def describe_punishing_quest(quest_desc: Iterable[str], grammar: Grammar, index_symbol='> '): + if len(quest_desc) == 0: + description = describe_punishing_quest_none(grammar) + else: + description = describe_punishing_quest(quest_desc, grammar, index_symbol) + + return description + + +def describe_punishing_quest_none(grammar: Grammar): + quest_tag = grammar.get_random_expansion("#punishing_quest_none#") + quest_desc = grammar.expand(quest_tag) + quest_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + quest_desc) return quest_desc -def describe_event(event: Event, game: Game, grammar: Grammar) -> str: +def describe_punishing_quest(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: + only_one_task = len(quest_desc) < 2 + quest_desc = [index_symbol + desc for desc in quest_desc if desc] + quest_txt = " \n ".join(desc for desc in quest_desc if desc) + quest_txt = ": \n " + quest_txt + + if only_one_task: + quest_tag = grammar.get_random_expansion("#punishing_quest_one_task#") + quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) + else: + quest_tag = grammar.get_random_expansion("#punishing_quest_tasks#") + quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) + + description = grammar.expand(quest_tag) + description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + description) + return description + + +def describe_quest(quest_desc: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], + grammar: Grammar, index_symbol='> '): + if len(quest_desc) == 0: + description = describe_quest_none(grammar) + else: + if isinstance(combination_rule, EventOr): + description = describe_quest_or(quest_desc, grammar, index_symbol) + elif isinstance(combination_rule, EventAnd): + description = describe_quest_and(quest_desc, grammar, index_symbol) + + return description + + +def describe_quest_none(grammar: Grammar): + quest_tag = grammar.get_random_expansion("#quest_none#") + quest_desc = grammar.expand(quest_tag) + quest_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + quest_desc) + return quest_desc + + +def describe_quest_or(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: + only_one_task = len(quest_desc) < 2 + quest_desc = [index_symbol + desc for desc in quest_desc if desc] + quest_txt = " \n ".join(desc for desc in quest_desc if desc) + quest_txt = ": \n " + quest_txt + + if only_one_task: + quest_tag = grammar.get_random_expansion("#quest_one_task#") + quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) + else: + quest_tag = grammar.get_random_expansion("#quest_or_tasks#") + quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) + + description = grammar.expand(quest_tag) + description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + description) + return description + + +def describe_quest_and(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: + only_one_task = len(quest_desc) < 2 + quest_desc = [index_symbol + desc for desc in quest_desc if desc] + quest_txt = " \n ".join(desc for desc in quest_desc if desc) + quest_txt = ": \n " + quest_txt + + if only_one_task: + quest_tag = grammar.get_random_expansion("#quest_one_task#") + quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) + else: + quest_tag = grammar.get_random_expansion("#quest_and_tasks#") + quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) + + description = grammar.expand(quest_tag) + description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + description) + return description + + +def assign_description_to_combined_events(events: Union[EventAnd, EventOr], game: Game, grammar: Grammar, index_symbol, + _desc=[]): + if isinstance(events, EventCondition) or isinstance(events, EventAction): + _desc += [assign_description_to_event(events, game, grammar)] + return + + index_symbol = '-' + index_symbol + desc, ev_type = [], [] + for event in events.events: + st = assign_description_to_combined_events(event, game, grammar, index_symbol, desc) + ev_type.append(isinstance(event, EventCondition) or isinstance(event, EventAction)) + + if st: + desc += [st] + + if all(ev_type): + st1 = combine_events(make_str(desc), events, grammar) + else: + st1 = combine_tasks(make_str(desc), events, grammar, index_symbol) + + return st1 + + +def combine_events(events: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], grammar: Grammar): + if len(events) == 0: + events_desc = "" + else: + if isinstance(combination_rule, EventOr): + events_desc = describe_event_or(events, grammar) + elif isinstance(combination_rule, EventAnd): + events_desc = describe_event_and(events, grammar) + + return events_desc + + +def describe_event_or(events_desc: Iterable[str], grammar: Grammar) -> str: + only_one_event = len(events_desc) < 2 + combined_event_txt = " , or, ".join(desc for desc in events_desc if desc) + combined_event_txt = ": " + combined_event_txt + + if only_one_event: + combined_event_tag = grammar.get_random_expansion("#combined_one_event#") + combined_event_tag = combined_event_tag.replace("(only_event)", combined_event_txt.strip()) + else: + combined_event_tag = grammar.get_random_expansion("#combined_or_events#") + combined_event_tag = combined_event_tag.replace("(list_of_events)", combined_event_txt.strip()) + + combined_event_desc = grammar.expand(combined_event_tag) + combined_event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + combined_event_desc) + + return combined_event_desc + + +def describe_event_and(events_desc: Iterable[str], grammar: Grammar) -> str: + only_one_event = len(events_desc) < 2 + combined_event_txt = " , and, ".join(desc for desc in events_desc if desc) + combined_event_txt = ": " + combined_event_txt + + if only_one_event: + combined_event_tag = grammar.get_random_expansion("#combined_one_event#") + combined_event_tag = combined_event_tag.replace("(only_event)", combined_event_txt.strip()) + else: + combined_event_tag = grammar.get_random_expansion("#combined_and_events#") + combined_event_tag = combined_event_tag.replace("(list_of_events)", combined_event_txt.strip()) + + combined_event_desc = grammar.expand(combined_event_tag) + combined_event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + combined_event_desc) + + return combined_event_desc + + +def combine_tasks(tasks: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], + grammar: Grammar, index_symbol: str): + if len(tasks) == 0: + tasks_desc = "" + else: + if isinstance(combination_rule, EventOr): + tasks_desc = describe_tasks_or(tasks, grammar, index_symbol) + if isinstance(combination_rule, EventAnd): + tasks_desc = describe_tasks_and(tasks, grammar, index_symbol) + + return tasks_desc + + +def describe_tasks_and(tasks_desc: Iterable[str], grammar: Grammar, index_symbol: str) -> str: + only_one_task = len(tasks_desc) < 2 + tasks_desc = [index_symbol + desc for desc in tasks_desc if desc] + tasks_txt = " \n ".join(desc for desc in tasks_desc if desc) + tasks_txt = ": \n " + tasks_txt + + if only_one_task: + combined_task_tag = grammar.get_random_expansion("#combined_one_task#") + combined_task_tag = combined_task_tag.replace("(only_task)", tasks_txt.strip()) + else: + combined_task_tag = grammar.get_random_expansion("#combined_and_tasks#") + combined_task_tag = combined_task_tag.replace("(list_of_tasks)", tasks_txt.strip()) + + combined_task_desc = grammar.expand(combined_task_tag) + combined_task_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + combined_task_desc) + return combined_task_desc + + +def describe_tasks_or(tasks_desc: Iterable[str], grammar: Grammar, index_symbol: str) -> str: + only_one_task = len(tasks_desc) < 2 + tasks_desc = [index_symbol + desc for desc in tasks_desc if desc] + tasks_txt = " \n ".join(desc for desc in tasks_desc if desc) + tasks_txt = ": \n " + tasks_txt + + if only_one_task: + combined_task_tag = grammar.get_random_expansion("#combined_one_task#") + combined_task_tag = combined_task_tag.replace("(only_task)", tasks_txt.strip()) + else: + combined_task_tag = grammar.get_random_expansion("#combined_and_tasks#") + combined_task_tag = combined_task_tag.replace("(list_of_tasks)", tasks_txt.strip()) + + combined_task_desc = grammar.expand(combined_task_tag) + combined_task_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", + lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), + combined_task_desc) + return combined_task_desc + + +def assign_description_to_event(events: Union[EventAction, EventCondition], game: Game, grammar: Grammar): + return describe_event(events, game, grammar) + + +def describe_event(event: Union[EventCondition, EventAction], game: Game, grammar: Grammar) -> str: """ Assign a descripton to a quest. """ @@ -424,7 +714,7 @@ def describe_event(event: Event, game: Game, grammar: Grammar) -> str: if grammar.options.blend_instructions: instructions = get_action_chains(event.actions, grammar, game) else: - instructions = event.actions + instructions = [act for act in event.actions] only_one_action = len(instructions) < 2 for c in instructions: @@ -434,19 +724,13 @@ def describe_event(event: Event, game: Game, grammar: Grammar) -> str: actions_desc_list.append(separator) actions_desc = " ".join(actions_desc_list) - if only_one_action: - quest_tag = grammar.get_random_expansion("#quest_one_action#") - quest_tag = quest_tag.replace("(action)", actions_desc.strip()) + event_tag = grammar.get_random_expansion("#event#") + event_tag = event_tag.replace("(list_of_actions)", actions_desc.strip()) - else: - quest_tag = grammar.get_random_expansion("#quest#") - quest_tag = quest_tag.replace("(list_of_actions)", actions_desc.strip()) - - event_desc = grammar.expand(quest_tag) + event_desc = grammar.expand(event_tag) event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), event_desc) - return event_desc diff --git a/textworld/generator/world.py b/textworld/generator/world.py index f96d5eab..095af9fc 100644 --- a/textworld/generator/world.py +++ b/textworld/generator/world.py @@ -256,9 +256,9 @@ def _process_rooms(self) -> None: room = self._get_room(fact.arguments[0]) room.add_related_fact(fact) - if fact.name.endswith("_of"): + if fact.definition.endswith("_of"): # Handle room positioning facts. - exit = reverse_direction(fact.name.split("_of")[0]) + exit = reverse_direction(fact.definition.split("_of")[0]) dest = self._get_room(fact.arguments[1]) dest.add_related_fact(fact) assert exit not in room.exits diff --git a/textworld/logic/__init__.py b/textworld/logic/__init__.py index 6cf7b0a5..32dcc87c 100644 --- a/textworld/logic/__init__.py +++ b/textworld/logic/__init__.py @@ -136,10 +136,10 @@ def walk_VariableNode(self, node): return self._walk_variable_ish(node, Variable) def walk_SignatureNode(self, node): - return Signature(node.name, node.types, node.verb, node.definition) + return Signature(node.name, node.types) def walk_PropositionNode(self, node): - return Proposition(node.name, self.walk(node.arguments), node.verb, node.definition) + return Proposition(node.name, self.walk(node.arguments)) def walk_ActionNode(self, node): return self._walk_action_ish(node, Action) @@ -148,7 +148,7 @@ def walk_PlaceholderNode(self, node): return self._walk_variable_ish(node, Placeholder) def walk_PredicateNode(self, node): - return Predicate(node.name, self.walk(node.parameters), node.verb, node.definition) + return Predicate(node.name, self.walk(node.parameters)) def walk_RuleNode(self, node): return self._walk_action_ish(node, Rule) @@ -554,9 +554,10 @@ def deserialize(cls, data: Mapping) -> "Variable": lambda cls, args, kwargs: ( cls, kwargs.get("name", args[0] if len(args) >= 1 else None), - tuple(kwargs.get("types", args[1] if len(args) >= 2 else [])), - kwargs.get("verb", args[2] if len(args) >= 3 else None), - kwargs.get("definition", args[3] if len(args) == 4 else None), + tuple(kwargs.get("types", args[1] if len(args) == 2 else [])) + # tuple(kwargs.get("types", args[1] if len(args) >= 2 else [])), + # kwargs.get("verb", args[2] if len(args) >= 3 else None), + # kwargs.get("definition", args[3] if len(args) == 4 else None), ) ) @@ -569,7 +570,7 @@ class Signature(with_metaclass(SignatureTracker, object)): __slots__ = ("name", "types", "_hash", "verb", "definition") - def __init__(self, name: str, types: Iterable[str], verb=None, definition=None): + def __init__(self, name: str, types: Iterable[str]): """ Create a Signature. @@ -580,22 +581,21 @@ def __init__(self, name: str, types: Iterable[str], verb=None, definition=None): types : The types of the parameters to the proposition/predicate. """ - if (not verb and definition) or (verb and not definition): - raise UnderspecifiedSignatureError if name.count('__') == 0: - verb = "is" - definition = name - name = "is__"+name + self.verb = "is" + self.definition = name + self.name = "is__" + name else: - verb = name[:name.find('__')] - definition = name[name.find('__') + 2:] + self.verb = name[:name.find('__')] + self.definition = name[name.find('__') + 2:] + self.name = name - self.name = name + # self.name = name self.types = tuple(types) - self.verb = verb - self.definition = definition - self._hash = hash((self.name, self.types, self.verb, self.definition)) + # self.verb = verb + # self.definition = definition + self._hash = hash((self.name, self.types)) def __str__(self): return "{}({})".format(self.name, ", ".join(map(str, self.types))) @@ -605,7 +605,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Signature): - return self.name == other.name and self.types == other.types and self.verb == other.verb and self.definition == other.definition + return self.name == other.name and self.types == other.types else: return NotImplemented @@ -636,11 +636,11 @@ def parse(cls, expr: str) -> "Signature": lambda cls, args, kwargs: ( cls, kwargs.get("name", args[0] if len(args) >= 1 else None), - tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) >= 2 else [])), - kwargs.get("verb", args[2] if len(args) >= 3 else None), - kwargs.get("definition", args[3] if len(args) >= 4 else None), - # kwargs.get("activate", 0) - kwargs.get("activate", args[4] if len(args) == 5 else 0) + tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) == 2 else [])), + # tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) >= 2 else [])), + # kwargs.get("verb", args[2] if len(args) >= 3 else None), + # kwargs.get("definition", args[3] if len(args) >= 4 else None), + # kwargs.get("activate", args[4] if len(args) == 5 else 0) ) ) @@ -653,8 +653,7 @@ class Proposition(with_metaclass(PropositionTracker, object)): __slots__ = ("name", "arguments", "signature", "_hash", "verb", "definition", "activate") - def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = None, definition: str = None, - activate: int = 0): + def __init__(self, name: str, arguments: Iterable[Variable] = []): """ Create a Proposition. @@ -666,28 +665,26 @@ def __init__(self, name: str, arguments: Iterable[Variable] = [], verb: str = No The variables this proposition is applied to. """ - if (not verb and definition) or (verb and not definition): - raise UnderspecifiedPropositionError - if name.count('__') == 0: - verb = "is" - definition = name - name = "is__"+name + self.verb = "is" + self.definition = name + self.name = "is__" + name else: - verb = name[:name.find('__')].replace('_', ' ') - definition = name[name.find('__') + 2:] + self.verb = name[:name.find('__')].replace('_', ' ') + self.definition = name[name.find('__') + 2:] + self.name = name - self.name = name + # self.name = name self.arguments = tuple(arguments) - self.verb = verb - self.definition = definition - self.signature = Signature(name, [var.type for var in self.arguments], self.verb, self.definition) - self._hash = hash((self.name, self.arguments, self.verb, self.definition)) - - if self.verb == 'is': - activate = 1 + # self.verb = verb + # self.definition = definition + self.signature = Signature(name, [var.type for var in self.arguments]) + self._hash = hash((self.name, self.arguments)) - self.activate = activate + # if self.verb == 'is': + # activate = 1 + # + # self.activate = activate @property def names(self) -> Collection[str]: @@ -711,8 +708,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Proposition): - return (self.name, self.arguments, self.verb, self.definition, self.activate) == \ - (other.name, other.arguments, other.verb, other.definition, other.activate) + return (self.name, self.arguments) == (other.name, other.arguments) else: return NotImplemented @@ -741,19 +737,18 @@ def serialize(self) -> Mapping: return { "name": self.name, "arguments": [var.serialize() for var in self.arguments], - "verb": self.verb, - "definition": self.definition, - "activate": self.activate + # "verb": self.verb, + # "definition": self.definition, } @classmethod def deserialize(cls, data: Mapping) -> "Proposition": name = data["name"] args = [Variable.deserialize(arg) for arg in data["arguments"]] - verb = data["verb"] - definition = data["definition"] - activate = data["activate"] - return cls(name, args, verb, definition, activate) + # verb = data["verb"] + # definition = data["definition"] + # activate = data["activate"] + return cls(name, args) @total_ordering @@ -837,7 +832,7 @@ class Predicate: A boolean-valued function over variables. """ - def __init__(self, name: str, parameters: Iterable[Placeholder], verb=None, definition=None): + def __init__(self, name: str, parameters: Iterable[Placeholder]): """ Create a Predicate. @@ -848,22 +843,21 @@ def __init__(self, name: str, parameters: Iterable[Placeholder], verb=None, defi parameters : The symbolic arguments to this predicate. """ - if (not verb and definition) or (verb and not definition): - raise UnderspecifiedPredicateError if name.count('__') == 0: - verb = "is" - definition = name - name = "is__" + name + self.verb = "is" + self.definition = name + self.name = "is__" + name else: - verb = name[:name.find('__')] - definition = name[name.find('__') + 2:] + self.verb = name[:name.find('__')] + self.definition = name[name.find('__') + 2:] + self.name = name - self.name = name + # self.name = name self.parameters = tuple(parameters) - self.verb = verb - self.definition = definition - self.signature = Signature(name, [ph.type for ph in self.parameters], self.verb, self.definition) + # self.verb = verb + # self.definition = definition + self.signature = Signature(name, [ph.type for ph in self.parameters]) @property def names(self) -> Collection[str]: @@ -887,12 +881,12 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Predicate): - return (self.name, self.parameters, self.verb, self.definition) == (other.name, other.parameters, other.verb, other.definition) + return (self.name, self.parameters) == (other.name, other.parameters) else: return NotImplemented def __hash__(self): - return hash((self.name, self.types, self.verb, self.definition)) + return hash((self.name, self.types)) def __lt__(self, other): if isinstance(other, Predicate): @@ -916,17 +910,18 @@ def serialize(self) -> Mapping: return { "name": self.name, "parameters": [ph.serialize() for ph in self.parameters], - "verb": self.verb, - "definition": self.definition + # "verb": self.verb, + # "definition": self.definition } @classmethod def deserialize(cls, data: Mapping) -> "Predicate": name = data["name"] params = [Placeholder.deserialize(ph) for ph in data["parameters"]] - verb = data["verb"] - definition = data["definition"] - return cls(name, params, verb, definition) + # verb = data["verb"] + # definition = data["definition"] + # return cls(name, params, verb, definition) + return cls(name, params) def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": """ @@ -939,7 +934,8 @@ def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": """ params = [mapping.get(param, param) for param in self.parameters] - return Predicate(self.name, params, self.verb, self.definition) + # return Predicate(self.name, params, self.verb, self.definition) + return Predicate(self.name, params) def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: """ @@ -956,13 +952,8 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: """ args = [mapping[param] for param in self.parameters] - return Proposition(self.name, arguments=args, verb=self.verb, definition=self.definition) - - # args = [mapping[param] for param in self.parameters] - # if Proposition.name == 'event': - # return Proposition(self.name, args, verb=special.verb, definition=special.definition) - # else: - # return Proposition(self.name, args) + # return Proposition(self.name, arguments=args, verb=self.verb, definition=self.definition) + return Proposition(self.name, arguments=args) def match(self, proposition: Proposition) -> Optional[Mapping[Placeholder, Variable]]: """ @@ -1160,8 +1151,9 @@ def activate_traceable(self): if not prop.name.startswith('is__'): prop.activate = 1 - def is_valid(self): - return all([prop.activate == 1 for prop in self.all_propositions]) + # def is_valid(self): + # aa = self.all_propositions + # return all([prop.activate == 1 for prop in self.all_propositions]) class Rule: @@ -1567,19 +1559,16 @@ def _predicate_diversity(self): new_preds = [] for pred in self.predicates: for v in ['was', 'has been', 'had been']: - new_preds.append(Signature(name=v.replace(' ', '_') + pred.name[pred.name.find('__'):], types=pred.types, - verb=v, definition=pred.definition)) + new_preds.append(Signature(name=v.replace(' ', '_') + pred.name[pred.name.find('__'):], types=pred.types)) self.predicates.update(set(new_preds)) def _inform7_predicates_diversity(self): new_preds = {} for k, v in self.inform7.predicates.items(): for vt in ['was', 'has been', 'had been']: - new_preds[Signature(name=vt.replace(' ', '_') + k.name[k.name.find('__'):], types=k.types, - verb=vt, definition=k.definition)] = \ + new_preds[Signature(name=vt.replace(' ', '_') + k.name[k.name.find('__'):], types=k.types)] = \ Inform7Predicate(predicate=Predicate(name=vt.replace(' ', '_') + v.predicate.name[v.predicate.name.find('__'):], - parameters=v.predicate.parameters, verb=vt, - definition=v.predicate.definition), + parameters=v.predicate.parameters), source=v.source.replace('is', vt)) self.inform7.predicates.update(new_preds) @@ -1703,10 +1692,6 @@ def are_facts(self, props: Iterable[Proposition]) -> bool: for prop in props: if not self.is_fact(prop): return False - - if not prop.activate: - return False - return True @property diff --git a/textworld/logic/parser.py b/textworld/logic/parser.py index e86ae355..d18db13a 100644 --- a/textworld/logic/parser.py +++ b/textworld/logic/parser.py @@ -117,20 +117,33 @@ def _variable_(self): # noqa [] ) + def _predVT_(self, name): + self._constant(name[:name.find('__')].replace('_', ' ')) + + def _predDef_(self, name): + self._constant(name[name.find('__') + 2:]) + @tatsumasu('SignatureNode') def _signature_(self): # noqa + self._predName_() + if self.cst.count('__') == 0: + self.last_node = 'is__' + self.cst self.name_last_node('name') + # self._predVT_(self.ast['name']) + # self.name_last_node('verb') + # self._predDef_(self.ast['name']) + # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') - def block2(): self._name_() self._gather(block2, sep2) self.name_last_node('types') self._token(')') + self.ast._define( ['name', 'types'], [] @@ -139,17 +152,23 @@ def block2(): @tatsumasu('PropositionNode') def _proposition_(self): # noqa self._predName_() + if self.cst.count('__') == 0: + self.last_node = 'is__' + self.cst self.name_last_node('name') + # self._predVT_(self.ast['name']) + # self.name_last_node('verb') + # self._predDef_(self.ast['name']) + # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') - def block2(): self._variable_() self._gather(block2, sep2) self.name_last_node('arguments') self._token(')') + self.ast._define( ['arguments', 'name'], [] @@ -210,17 +229,23 @@ def _placeholder_(self): # noqa @tatsumasu('PredicateNode') def _predicate_(self): # noqa self._predName_() + if self.cst.count('__') == 0: + self.last_node = 'is__' + self.cst self.name_last_node('name') + # self._predVT_(self.ast['name']) + # self.name_last_node('verb') + # self._predDef_(self.ast['name']) + # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') - def block2(): self._placeholder_() self._gather(block2, sep2) self.name_last_node('parameters') self._token(')') + self.ast._define( ['name', 'parameters'], [] From dbf47fa97cc52d72d38ed6290f98cd41969db5fb Mon Sep 17 00:00:00 2001 From: HakiRose Date: Mon, 11 May 2020 20:32:41 -0400 Subject: [PATCH 4/5] Updates on new TextWorld framework structure include: updates & developments on test files, some updates on the frame work, semi-final updates on Content Check Game, etc. --- .../data/text_grammars/house_quests.twg | 56 +++ textworld/generator/game.py | 278 +++----------- textworld/generator/inform7/world2inform7.py | 87 ----- textworld/generator/maker.py | 119 +++--- textworld/generator/tests/test_game.py | 346 +++++++++++++----- textworld/generator/tests/test_maker.py | 12 +- .../generator/tests/test_text_generation.py | 7 +- textworld/generator/vtypes.py | 12 +- textworld/logic/__init__.py | 4 + 9 files changed, 441 insertions(+), 480 deletions(-) create mode 100644 textworld/generator/data/text_grammars/house_quests.twg diff --git a/textworld/generator/data/text_grammars/house_quests.twg b/textworld/generator/data/text_grammars/house_quests.twg new file mode 100644 index 00000000..b220d1c4 --- /dev/null +++ b/textworld/generator/data/text_grammars/house_quests.twg @@ -0,0 +1,56 @@ +#------------------------- +#Quests Grammar +#------------------------- + +punishing_quest_none:#punishing_prologue_none# +punishing_prologue_none: in this quest, there are some activities that if you do them, you will be punished. You will find out what are those.;your activities matter in this quest, some of them will have punishment for you. Explore the environment.;your activities changes the state of the game. There is a state which is dangerous for you, you\'ll be charged if you get there. + +punishing_quest_one_task:#punishing_prologue# (combined_task) + +punishing_quest_tasks:#punishing_prologue# #AndOr# (list_of_combined_tasks) +AndOr:complete any of the following tasks;do none of the following tasks;do any single of the following tasks; commit any of the following tasks + +punishing_prologue: your mission for this quest is not to; your task for this quest is not to; there is something I need you to be careful about it. Please never; your objective is not to; please hesitate to + +quest_none:#prologue_none# +prologue_none: there are some activities I need you to do for me in this quest. You will find out what are those.;your activities matter in this quest. Explore the environment.;your activities changes the state of the game. There is a state which is important to this quest. + +quest_one_task:#prologue# (combined_task) + +quest_and_tasks:#prologue# #And# (list_of_combined_tasks) +And:complete all the following tasks;do every single of the following tasks; finish every single of the following tasks + +quest_or_tasks:#prologue# #Or# (list_of_combined_tasks) +Or:complete any of the following tasks;do at least one of the following tasks; finish one or more of the following tasks + +prologue: your mission for this quest is to; your task for this quest is to; there is something I need you to do for me. Please ; your objective is to; please + +combined_one_task:#prologue_combined_one_task# (only_task) +prologue_combined_one_task:there is only one event to do in this task;the only objective event here is; + +combined_and_tasks:#prologue_combined_and_tasks# (list_of_tasks) +prologue_combined_and_tasks:make sure to complete all these events in this task;complete all the events for this task;do all these events; all these events need to be done in this task + +combined_or_tasks:#prologue_combined_or_tasks# (list_of_tasks) +prologue_combined_or_tasks:make sure to complete any of these events in this task;complete at least one of the following events in this in this task;do minimum one of these events here; at least one of these events need to be done for this task + +combined_one_event:#prologue_combined_one_event# (only_event) +prologue_combined_one_event:The event includes the following event;the objective of this event is given as follows + +combined_and_events:#prologue_combined_and_events# (list_of_events) +prologue_combined_and_events:complete all these actions;do all these actions; all these actions need to be done + +combined_or_events:#prologue_combined_or_events# (list_of_events) +prologue_combined_or_events:complete any of these actions;doing any of these actions is sufficient;Completing any from the following actions is okay + +event:(list_of_actions) + + +all_quests_non: #prologue_quests# #epilogue# + +all_quests: #prologue_quests# (quests_string) #epilogue# + +prologue_quests:#welcome#;Hey, thanks for coming over to the TextWorld today, there is something I need you to do for me.;Hey, thanks for coming over to TextWorld! Your activity matters here. +epilogue:Once that's handled, you can stop!;You will know when you can stop!;And once you've done, you win!;You're the winner when all activities are taken!;Got that? Great!;Alright, thanks! +welcome:Welcome to TextWorld;You are now playing a #exciting# #game# of TextWorld Spaceship;Welcome to another #exciting# #game# of TextWorld;It's time to explore the amazing world of TextWorld Galaxy;Get ready to pick stuff up and put it in places, because you've just entered TextWorld shuttle;I hope you're ready to go into rooms and interact with objects, because you've just entered TextWorld shuttle;Who's got a virtual machine and is about to play through an #exciting# round of TextWorld? You do; +game:game;round;session;episode diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 8a5c50dd..4b731bc9 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -43,7 +43,7 @@ def __init__(self): class UnderspecifiedEventActionError(NameError): def __init__(self): - msg = "No action is defined, action is required to create an event." + msg = "The EventAction includes ONLY one action." super().__init__(msg) @@ -182,10 +182,6 @@ def serialize(self) -> Mapping: """ return {"commands_Event": self.commands, "actions_Event": [action.serialize() for action in self.actions]} - # data = {} - # data["commands"] = self.commands - # data["actions"] = [action.serialize() for action in self.actions] - # return data def copy(self) -> "Event": """ Copy this event. """ @@ -286,12 +282,6 @@ def serialize(self) -> Mapping: "actions_EventCondition": [action.serialize() for action in self.actions], "condition_EventCondition": self.condition.serialize(), "verb_tense_EventCondition": self.verb_tense} - # data = {} - # data["commands"] = self.commands - # data["actions"] = [action.serialize() for action in self.actions] - # data["condition"] = self.condition.serialize() - # data["verb_tense"] = self.verb_tense - # return data def copy(self) -> "EventCondition": """ Copy this event. """ @@ -311,6 +301,9 @@ def __init__(self, actions: Iterable[Action] = (), """ super(EventAction, self).__init__(actions, commands) + if self.is_valid(): + raise UnderspecifiedEventActionError + self.verb_tense = verb_tense self.traceable = self.set_actions() @@ -329,7 +322,7 @@ def set_actions(self): return [prop for ar in traceable for prop in ar] def is_valid(self): - return len(self.actions) != 0 + return len(self.actions) != 1 def is_triggering(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this event would be triggered for a given action. """ @@ -377,10 +370,6 @@ def serialize(self) -> Mapping: "commands_EventAction": self.commands, "verb_tense_EventAction": self.verb_tense, } - # return {"actions": [action.serialize() for action in self.actions], - # "commands": self.commands, - # "verb_tense": self.verb_tense, - # } def copy(self) -> "EventAction": """ Copy this event. """ @@ -388,7 +377,7 @@ def copy(self) -> "EventAction": class EventOr: - def __init__(self, events=()): + def __init__(self, events: Tuple =()): self.events = events self._any_triggered = False self._any_untriggered = False @@ -402,23 +391,14 @@ def events(self, events) -> None: self._events = tuple(events) def are_triggering(self, state, action): - # status_i, status_t = [], [] status = [] for ev in self.events: if isinstance(ev, EventCondition) or isinstance(ev, EventAction): status.append(ev.is_triggering(state, [action])) - # status_i.append(ev.is_triggering(state, action)) continue status.append(ev.are_triggering(state, action)) - # status_t.append(ev.are_triggering(state, action)) return any(status) - # status = [] - # for ev in self.events: - # if isinstance(ev, EventCondition) or isinstance(ev, EventAction): - # status.append(ev.is_triggering(state, action)) - # status - # return any(status) def are_events_triggered(self, state, action): return any((ev.is_triggering(state, action) for ev in self.events)) @@ -466,7 +446,7 @@ def copy(self) -> "EventOr": class EventAnd: - def __init__(self, events=()): + def __init__(self, events: Tuple = ()): self.events = events self._all_triggered = False self._all_untriggered = False @@ -480,17 +460,12 @@ def events(self, events) -> None: self._events = tuple(events) def are_triggering(self, state, action): - # status_i, status_t = [], [] status = [] for ev in self.events: if isinstance(ev, EventCondition) or isinstance(ev, EventAction): status.append(ev.is_triggering(state, [action])) - # status_i.append(ev.is_triggering(state, action)) continue status.append(ev.are_triggering(state, action)) - # status_t.append(ev.are_triggering(state, action)) - # status_i - # status_t return all(status) def are_events_triggered(self, state, action): @@ -656,11 +631,11 @@ def events_organizer(self, combined_events=()): def is_winning(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this quest is winning in that particular state. """ - return any(event.is_triggering(state, actions) for event in self.win_events) + return any(event.are_triggering(state, actions) for event in self.win_events) def is_failing(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: """ Check if this quest is failing in that particular state. """ - return any(event.is_triggering(state, actions) for event in self.fail_events) + return any(event.are_triggering(state, actions) for event in self.fail_events) def __hash__(self) -> int: return hash((self.win_events, self.fail_events, self.reward, self.desc, self.commands)) @@ -718,154 +693,6 @@ def copy(self) -> "Quest": """ Copy this quest. """ return self.deserialize(self.serialize()) -# class Quest: -# """ Quest representation in TextWorld. -# -# A quest is defined by a mutually exclusive set of winning events and -# a mutually exclusive set of failing events. -# -# Attributes: -# win_events: Mutually exclusive set of winning events. That is, -# only one such event needs to be triggered in order -# to complete this quest. -# fail_events: Mutually exclusive set of failing events. That is, -# only one such event needs to be triggered in order -# to fail this quest. -# reward: Reward given for completing this quest. -# desc: A text description of the quest. -# commands: List of text commands leading to this quest completion. -# """ -# -# def __init__(self, -# win_events: Iterable[Union[Event, EventCondition, EventAction]] = (), -# fail_events: Iterable[Union[Event, EventCondition, EventAction]] = (), -# reward: Optional[int] = None, -# desc: Optional[str] = None, -# commands: Iterable[str] = ()) -> None: -# r""" -# Args: -# win_events: Mutually exclusive set of winning events. That is, -# only one such event needs to be triggered in order -# to complete this quest. -# fail_events: Mutually exclusive set of failing events. That is, -# only one such event needs to be triggered in order -# to fail this quest. -# reward: Reward given for completing this quest. By default, -# reward is set to 1 if there is at least one winning events -# otherwise it is set to 0. -# desc: A text description of the quest. -# commands: List of text commands leading to this quest completion. -# """ -# self.win_events = tuple(win_events) -# self.fail_events = tuple(fail_events) -# self.desc = desc -# self.commands = tuple(commands) -# -# # Unless explicitly provided, reward is set to 1 if there is at least -# # one winning events otherwise it is set to 0. -# self.reward = int(len(win_events) > 0) if reward is None else reward -# -# if len(self.win_events) == 0 and len(self.fail_events) == 0: -# raise UnderspecifiedQuestError() -# -# @property -# def win_events(self) -> Iterable[Union[Event, EventCondition, EventAction]]: -# return self._win_events -# -# @win_events.setter -# def win_events(self, events: Iterable[Union[Event, EventCondition, EventAction]]) -> None: -# self._win_events = tuple(events) -# -# @property -# def fail_events(self) -> Iterable[Union[Event, EventCondition, EventAction]]: -# return self._fail_events -# -# @fail_events.setter -# def fail_events(self, events: Iterable[Union[Event, EventCondition, EventAction]]) -> None: -# self._fail_events = tuple(events) -# -# @property -# def commands(self) -> Iterable[str]: -# return self._commands -# -# @commands.setter -# def commands(self, commands: Iterable[str]) -> None: -# self._commands = tuple(commands) -# -# def is_winning(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: -# """ Check if this quest is winning in that particular state. """ -# -# return any(event.is_triggering(state, actions) for event in self.win_events) -# -# def is_failing(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: -# """ Check if this quest is failing in that particular state. """ -# return any(event.is_triggering(state, actions) for event in self.fail_events) -# -# def __hash__(self) -> int: -# return hash((self.win_events, self.fail_events, self.reward, self.desc, self.commands)) -# -# def __eq__(self, other: Any) -> bool: -# return (isinstance(other, Quest) -# and self.win_events == other.win_events -# and self.fail_events == other.fail_events -# and self.reward == other.reward -# and self.desc == other.desc -# and self.commands == other.commands) -# -# @classmethod -# def deserialize(cls, data: Mapping) -> "Quest": -# """ Creates a `Quest` from serialized data. -# -# Args: -# data: Serialized data with the needed information to build a -# `Quest` object. -# """ -# -# win_events = [] -# for d in data["win_events"]: -# if "action_verb_tense" in d.keys(): -# win_events.append(Event.deserialize(d)) -# -# elif "output_verb_tense" in d.keys() and "commands" in d.keys(): -# win_events.append(EventCondition.deserialize(d)) -# -# else: -# win_events.append(EventAction.deserialize(d)) -# -# fail_events = [] -# for d in data["fail_events"]: -# if "action_verb_tense" in d.keys(): -# fail_events.append(Event.deserialize(d)) -# -# elif "output_verb_tense" in d.keys() and "commands" in d.keys(): -# fail_events.append(EventCondition.deserialize(d)) -# -# else: -# fail_events.append(EventAction.deserialize(d)) -# -# commands = data.get("commands", []) -# reward = data["reward"] -# desc = data["desc"] -# return cls(win_events, fail_events, reward, desc, commands) -# -# def serialize(self) -> Mapping: -# """ Serialize this quest. -# -# Results: -# Quest's data serialized to be JSON compatible -# """ -# data = {} -# data["desc"] = self.desc -# data["reward"] = self.reward -# data["commands"] = self.commands -# data["win_events"] = [event.serialize() for event in self.win_events] -# data["fail_events"] = [event.serialize() for event in self.fail_events] -# return data -# -# def copy(self) -> "Quest": -# """ Copy this quest. """ -# return self.deserialize(self.serialize()) - class EntityInfo: """ Additional information about entities in the game. """ @@ -949,14 +776,11 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None, self.quests = tuple(quests) self.metadata = {} self._objective = None - # self.objective self._infos = self._build_infos() self.kb = world.kb self.change_grammar(grammar) - - @property def infos(self) -> Dict[str, EntityInfo]: """ Information about the entities in the game. """ @@ -1003,33 +827,12 @@ def change_grammar(self, grammar: Grammar) -> None: # Check if we can derive a global winning policy from the quests. if self.grammar: - from textworld.generator.text_generation import describe_event policy = GameProgression(self).winning_policy if policy: mapping = {k: info.name for k, info in self._infos.items()} commands = [a.format_command(mapping) for a in policy] self.metadata["walkthrough"] = commands - self.objective = describe_event(EventCondition(policy), self, self.grammar) - - def command_generator(self, events, _gen_commands): - for event in events: - events.commands = _gen_commands(events.actions) - - # def command_generator(self, events, _gen_commands, quest): - # if isinstance(events, EventCondition) or isinstance(events, EventAction): - # events.commands = _gen_commands(events.actions) - # # quest.append(events.actions) - # quest.append(events) - # return - # - # act = [] - # for event in events.events: - # out = self.command_generator(event, _gen_commands, act) - # if out: - # for a in out: - # quest.append(a) - # - # return (len(act) > 0 and len(act) > len(quest)) * act or (len(quest) > 0 and len(quest) > len(act)) * quest + # self.objective = describe_event(EventCondition(actions=policy), self, self.grammar) def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ @@ -1300,7 +1103,7 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event: Union[Event, EventCondition, EventAction], kb: KnowledgeBase) -> None: + def __init__(self, event: Union[EventAnd, EventOr], kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. @@ -1312,33 +1115,39 @@ def __init__(self, event: Union[Event, EventCondition, EventAction], kb: Knowled self._policy = () # Build a tree representation of the quest. - self._tree = ActionDependencyTree(kb=self._kb, - element_type=ActionDependencyTreeElement) - - self.tree_policy(event) - # if not isinstance(event, EventAction) and not isinstance(event, Event): - # if len(event.actions) > 0: - # self._tree.push(event.condition) - # - # for action in event.actions[::-1]: - # self._tree.push(action) - # - # self._policy = event.actions + (event.condition,) + self._tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) - def tree_policy(self, event): - if isinstance(event, EventCondition) or isinstance(event, EventAction): - if isinstance(event, EventCondition): - if len(event.actions) > 0: - self._tree.push(event.condition) + action_list, _ = self.tree_policy(event) + for action in action_list: + self._tree.push(action) + self._policy = [a for a in action_list[::-1]] - for action in event.actions[::-1]: - self._tree.push(action) + def tree_policy(self, event): - self._policy = event.actions + (event.condition,) - return + if isinstance(event, EventCondition) or isinstance(event, EventAction): + if isinstance(event, EventCondition) and len(event.actions) > 0: + return [event.condition] + [action for action in event.actions[::-1]], 1 + elif isinstance(event, EventAction) and len(event.actions) > 0: + return [action for action in event.actions[::-1]], 0 + else: + return [], 1 + _actions, _ev_type = [], [] for ev in event.events: - self.tree_policy(ev) + a, b = self.tree_policy(ev) + _actions.append(a) + _ev_type.append(b) + + if isinstance(event, EventAnd): + act_list = [a for act in [x for _, x in sorted(zip(_ev_type, _actions))] for a in act] + elif isinstance(event, EventOr): + _actions = [x for x in _actions if len(x) > 0] + if _actions: + act_list = min(_actions, key=lambda act: len(act)) + else: + act_list = [] + + return act_list, 0 def copy(self) -> "EventProgression": """ Return a soft copy. """ @@ -1422,7 +1231,6 @@ def _find_shorter_policy(policy): self._tree.push(action) return shorter_policy - return None compressed = False @@ -1622,6 +1430,12 @@ def winning_policy(self) -> Optional[List[Action]]: master_quest_tree = ActionDependencyTree(kb=self.game.kb, element_type=ActionDependencyTreeElement, trees=trees) + actions = tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") + for action in actions: + if not action.command_template: + m = {c: d for c in self.game.kb.rules[action.name].placeholders for d in action.variables if c.type == d.type} + substitutions = {ph.name: "{{{}}}".format(var.name) for ph, var in m.items()} + action.command_template = self.game.kb.rules[action.name].command_template.format(**substitutions) # Discard all "trigger" actions. return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") @@ -1637,7 +1451,6 @@ def any_traceable_exist(self, events): return any(trc_exist) def add_traceables(self, action): - s = self.state.facts trace = [] for quest_progression in self.quest_progressions: if quest_progression.quest.reward >= 0: @@ -1685,7 +1498,6 @@ def update(self, action: Action) -> None: # Get valid actions. self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) - x = 0 class GameOptions: diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index 387afad8..08c3f0c6 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -1127,93 +1127,6 @@ def get_events(self, combined_events, txt, rmv, quest_id, rwd_conds, action_id=[ return tp_txt, tp_rmv, fin_cond, [action_id, cond_id, rwd_conds], cond_type - - # def get_events(self, combined_events, txt, rmv, quest_id, action_id=[], state_id=[], cond_id=[], check_vars=[], rwd_conds=[]): - # - # remove_action_processing_template = textwrap.dedent("""if the action{action_id} check is true: - # Now the action{action_id} check is false. - # """) - # remove_condition_processing_template = textwrap.dedent("""if the condition{cond_id} of quest{quest_id} check is true: - # Now the condition{cond_id} of quest{quest_id} check is false. - # """) - # action_processing_template = textwrap.dedent(""" - # The action{action_id} check is a truth state that varies. - # The action{action_id} check is usually false. - # After {actions}: - # Now the action{action_id} check is true. - # """) - # state_processing_template = textwrap.dedent(""" - # The state{state_id} of quest{quest_id} check is a truth state that varies. - # The state{state_id} of quest{quest_id} check is usually false. - # Every turn: - # if {conditions}: - # Now the state{state_id} of quest{quest_id} check is true. - # """) - # combined_ac_processing_template = textwrap.dedent(""" - # The condition{cond_id} of quest{quest_id} check is a truth state that varies. - # The condition{cond_id} of quest{quest_id} check is usually false. - # Every turn: - # if {conditions}: - # Now the condition{cond_id} of quest{quest_id} check is true. - # """) - # - # if isinstance(combined_events, EventCondition) or isinstance(combined_events, EventAction): - # if isinstance(combined_events, EventCondition): - # i7_ = self.gen_source_for_conditions(combined_events.condition.preconditions) - # # txt += [i7_] - # check_vars.append(i7_) - # state_id += [1] - # return [None] * 6 - # - # elif isinstance(combined_events, EventAction): - # if not rwd_conds or 'action{action_id}'.format(action_id=len(action_id)) not in rwd_conds: - # i7_ = self.gen_source_for_actions(combined_events.actions) - # txt += [action_processing_template.format(quest_id=quest_id, action_id=len(action_id), actions=i7_)] - # rmv += [remove_action_processing_template.format(quest_id=quest_id, action_id=len(action_id))] - # check_vars.append('action{action_id} check is true'.format(action_id=len(action_id))) - # rwd_conds += {i7_: 'action{action_id}'.format(action_id=len(action_id))} - # action_id += [1] - # else: - # rmv += [remove_action_processing_template.format(quest_id=quest_id, action_id=len(action_id))] - # check_vars.append('action{action_id} check is true'.format(action_id=len(action_id))) - # - # return [None] * 6 - # - # # act_type = [] - # act_type, _txt, _rmv, _check_vars, _cond_id, _rwd_conds = [], [], [], [], [], [] - # # _cond_id, _rwd_conds = [], [] - # for event in combined_events.events: - # st, rm, _, _, cond_type, rwd = self.get_events(event, _txt, _rmv, quest_id, action_id, state_id, cond_id, - # _check_vars, _rwd_conds) - # act_type.append(isinstance(event, EventAction)) - # - # if st: - # _txt += [st] - # _rmv += [rm] - # _check_vars.append('condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id)-1, - # quest_id=quest_id)) - # # _rwd_conds += rwd - # - # if cond_type: - # _cond_id += cond_type - # - # if any(_cond_id): - # _rmv += [remove_condition_processing_template.format(quest_id=quest_id, cond_id=len(cond_id) - 1)] - # - # event_rule = isinstance(combined_events, EventAnd) * ' and ' + isinstance(combined_events, EventOr) * ' or ' - # # _rwd_conds += _check_vars - # condition_ = event_rule.join(cv for cv in _check_vars) - # tp_txt = ''.join(tx for tx in _txt) - # tp_txt += combined_ac_processing_template.format(quest_id=quest_id, cond_id=len(cond_id), conditions=condition_) - # tp_rmv = ' '.join(ac for ac in _rmv if ac) - # fin_cond = 'condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id), quest_id=quest_id) - # cond_id += [1] - # if any(act_type): - # cond_type = [True] - # - # return tp_txt, tp_rmv, fin_cond, [action_id, state_id, cond_id, _rwd_conds], cond_type, _rwd_conds - - def generate_inform7_source(game: Game, seed: int = 1234, use_i7_description: bool = False) -> str: inform7 = Inform7Game(game) inform7.use_i7_description = use_i7_description diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 68f2c83e..bfecfb6e 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -45,6 +45,45 @@ def get_failing_constraints(state, kb: Optional[KnowledgeBase] = None): return failed_constraints +def new_operation(operation={}): + def func(operator='or', events=[]): + if operator == 'or' and events: + return EventOr(events=events) + if operator == 'and' and events: + return EventAnd(events=events) + else: + raise + + if not isinstance(operation, dict): + if len(operation) == 0: + return () + else: + operation = {'or': tuple(ev for ev in operation)} + + y1 = [] + for k, v in operation.items(): + if isinstance(v, dict): + y1.append(new_operation(operation=v)[0]) + y1 = [func(k, y1)] + else: + if isinstance(v, EventCondition) or isinstance(v, EventAction): + y1.append(func(k, [v])) + else: + if any((isinstance(it, dict) for it in v)): + y2 = [] + for it in v: + if isinstance(it, dict): + y2.append(new_operation(operation=it)[0]) + else: + y2.append(func(k, [it])) + + y1 = [func(k, y2)] + else: + y1.append(func(k, v)) + + return tuple(y1) + + class MissingPlayerError(ValueError): pass @@ -78,7 +117,7 @@ def __init__(self, failed_constraints: List[Action]) -> None: class UnderspecifiedEventError(NameError): def __init__(self): - msg = "Either the actions or the conditions is required to create an event. Both cannot be provided." + msg = "The event type should be specified. It can be either the action or condition." super().__init__(msg) @@ -671,8 +710,8 @@ def set_quest_from_commands(self, commands: List[str]) -> Quest: unrecognized_commands = [c for c, a in zip(commands, recorder.actions) if a is None] raise QuestError("Some of the actions were unrecognized: {}".format(unrecognized_commands)) - event = EventCondition(actions=actions, conditions=winning_facts) - self.quests = [Quest(win_events=[event])] + event = self.new_event(action=actions, condition=winning_facts, command=commands, event_style=event_style) + self.quests = [self.new_quest(win_event=[event])] # Calling build will generate the description for the quest. self.build() @@ -720,67 +759,26 @@ def new_conditions(conditions, args): return None def new_event(self, action: Iterable[Action] = (), condition: Iterable[Proposition] = (), - command: Iterable[str] = (), condition_verb_tense: dict = (), action_verb_tense: dict = ()): - - if action and condition: - raise UnderspecifiedEventError - - if action: - event = EventAction(actions=action, verb_tense=action_verb_tense, commands=command) - - elif condition: + command: Iterable[str] = (), condition_verb_tense: dict = (), action_verb_tense: dict = (), + event_style: str = 'condition'): + if event_style == 'condition': event = EventCondition(conditions=condition, verb_tense=condition_verb_tense, actions=action, commands=command) - - # return tuple(ev for ev in [event] if ev) - return event - - def new_operation(self, operation={}): - def func(operator='or', events=[]): - if operator == 'or' and events: - return EventOr(events=events) - if operator == 'and' and events: - return EventAnd(events=events) - else: - raise - - if not isinstance(operation, dict): - if len(operation) == 0: - return () - else: - raise - - y1 = [] - for k, v in operation.items(): - if isinstance(v, dict): - y1.append(self.new_operation(operation=v)[0]) - y1 = [func(k, y1)] - else: - if isinstance(v, EventCondition) or isinstance(v, EventAction): - y1.append(func(k, [v])) - else: - if any((isinstance(it, dict) for it in v)): - y2 = [] - for it in v: - if isinstance(it, dict): - y2.append(self.new_operation(operation=it)[0]) - else: - y2.append(func(k, [it])) - - y1 = [func(k, y2)] - else: - y1.append(func(k, v)) - - return tuple(y1) + return event + elif event_style == 'action': + event = EventAction(actions=action, verb_tense=action_verb_tense, commands=command) + return event + else: + raise UnderspecifiedEventError def new_quest(self, win_event=(), fail_event=(), reward=None, desc=None, commands=()) -> Quest: - return Quest(win_events=self.new_operation(operation=win_event), - fail_events=self.new_operation(operation=fail_event), + return Quest(win_events=new_operation(operation=win_event), + fail_events=new_operation(operation=fail_event), reward=reward, desc=desc, commands=commands) - def new_event_using_commands(self, commands: List[str]) -> EventCondition: + def new_event_using_commands(self, commands: List[str], event_style: str) -> Union[EventCondition, EventAction]: """ Creates a new event using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -802,10 +800,10 @@ def new_event_using_commands(self, commands: List[str]) -> EventCondition: # Skip "None" actions. actions, commands = zip(*[(a, c) for a, c in zip(recorder.actions, commands) if a is not None]) - event = EventCondition(actions=actions, commands=commands) + event = self.new_event(action=actions, command=commands, event_style=event_style) return event - def new_quest_using_commands(self, commands: List[str]) -> Quest: + def new_quest_using_commands(self, commands: List[str], event_style: str) -> Quest: """ Creates a new quest using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -816,8 +814,8 @@ def new_quest_using_commands(self, commands: List[str]) -> Quest: Returns: The resulting quest. """ - event = self.new_event_using_commands(commands) - return Quest(win_events=[event], commands=event.commands) + event = self.new_event_using_commands(commands, event_style=event_style) + return Quest(win_events=new_operation(operation=[event]), commands=event.commands) def set_walkthrough(self, commands: List[str]): with make_temp_directory() as tmpdir: @@ -955,6 +953,7 @@ def render(self, interactive: bool = False): :param filename: filename for screenshot """ game = self.build(validate=False) + game.change_grammar(self.grammar) # Generate missing object names. return visualize(game, interactive=interactive) def import_graph(self, G: nx.Graph) -> List[WorldRoom]: diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index 7cc47d86..abc4d40f 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -15,12 +15,13 @@ from textworld.generator.data import KnowledgeBase from textworld.generator import World from textworld.generator import make_small_map +from textworld.generator.maker import new_operation from textworld.generator.chaining import ChainingOptions, sample_quest from textworld.logic import Action from textworld.generator.game import GameOptions -from textworld.generator.game import Quest, Game, Event +from textworld.generator.game import Quest, Game, Event, EventAction, EventCondition, EventOr, EventAnd from textworld.generator.game import QuestProgression, GameProgression, EventProgression from textworld.generator.game import UnderspecifiedEventError, UnderspecifiedQuestError from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement @@ -119,7 +120,7 @@ def test_variable_infos(verbose=False): assert var_infos.desc is not None -class TestEvent(unittest.TestCase): +class TestEventCondition(unittest.TestCase): @classmethod def setUpClass(cls): @@ -139,28 +140,193 @@ def setUpClass(cls): chest.add_property("open") R1.add(chest) - cls.event = M.new_event_using_commands(commands) + cls.event = M.new_event_using_commands(commands, event_style='condition') cls.actions = cls.event.actions + cls.traceable = cls.event.traceable cls.conditions = {M.new_fact("in", carrot, chest)} def test_init(self): - event = Event(self.actions) + event = EventCondition(actions=self.actions) assert event.actions == self.actions assert event.condition == self.event.condition + assert event.traceable == self.traceable assert event.condition.preconditions == self.actions[-1].postconditions assert set(event.condition.preconditions).issuperset(self.conditions) - event = Event(conditions=self.conditions) + event = EventCondition(conditions=self.conditions) assert len(event.actions) == 0 + assert event.traceable == self.traceable assert set(event.condition.preconditions) == set(self.conditions) - npt.assert_raises(UnderspecifiedEventError, Event, actions=[]) - npt.assert_raises(UnderspecifiedEventError, Event, actions=[], conditions=[]) - npt.assert_raises(UnderspecifiedEventError, Event, conditions=[]) + npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[]) + npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[], conditions=[]) + npt.assert_raises(UnderspecifiedEventError, EventCondition, conditions=[]) def test_serialization(self): data = self.event.serialize() - event = Event.deserialize(data) + event = EventCondition.deserialize(data) + assert event == self.event + + def test_copy(self): + event = self.event.copy() + assert event == self.event + assert id(event) != id(self.event) + + +class TestEventAction(unittest.TestCase): + + @classmethod + def setUpClass(cls): + M = GameMaker() + + # The goal + commands = ["take carrot"] + + R1 = M.new_room("room") + M.set_player(R1) + + carrot = M.new(type='f', name='carrot') + R1.add(carrot) + + # Add a closed chest in R2. + chest = M.new(type='c', name='chest') + chest.add_property("open") + R1.add(chest) + + cls.event = M.new_event_using_commands(commands, event_style='action') + cls.actions = cls.event.actions + cls.traceable = cls.event.traceable + + def test_init(self): + event = EventAction(actions=self.actions) + assert event.actions == self.actions + assert event.traceable == self.traceable + + npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[]) + + def test_serialization(self): + data = self.event.serialize() + event = EventAction.deserialize(data) + assert event == self.event + + def test_copy(self): + event = self.event.copy() + assert event == self.event + assert id(event) != id(self.event) + + +class TestEventOr(unittest.TestCase): + + @classmethod + def setUpClass(cls): + + M = GameMaker() + + # The goal + commands = ["take lime juice", "insert lime juice into chest", "take carrot"] + + R1 = M.new_room("room") + M.set_player(R1) + + lime = M.new(type='f', name='lime juice') + R1.add(lime) + + carrot = M.new(type='f', name='carrot') + R1.add(carrot) + + # Add a closed chest in R2. + chest = M.new(type='c', name='chest') + chest.add_property("open") + R1.add(chest) + + cls.first_event = M.new_event_using_commands(commands[:-1], event_style='condition') + cls.first_event_actions = cls.first_event.actions + cls.first_event_traceable = cls.first_event.traceable + cls.first_event_conditions = {M.new_fact("in", lime, chest)} + + cls.second_event = M.new_event_using_commands([commands[-1]], event_style='action') + cls.second_event_actions = cls.second_event.actions + cls.second_event_traceable = cls.second_event.traceable + + cls.event = EventOr(events=(cls.first_event, cls.second_event)) + cls.events = cls.event.events + + def test_init(self): + first_event = EventCondition(actions=self.first_event_actions) + second_event = EventAction(actions=self.second_event_actions) + event = EventOr(events=(first_event, second_event)) + + assert event.events[0].actions == self.first_event_actions + assert event.events[0].condition == self.first_event.condition + assert event.events[0].traceable == self.first_event_traceable + assert event.events[0].condition.preconditions == self.first_event_actions[-1].postconditions + assert set(event.events[0].condition.preconditions).issuperset(self.first_event_conditions) + assert event.events[1].actions == self.second_event_actions + assert event.events[1].traceable == self.second_event_traceable + + def test_serialization(self): + data = self.event.serialize() + event = EventOr.deserialize(data) + assert event == self.event + + def test_copy(self): + event = self.event.copy() + assert event == self.event + assert id(event) != id(self.event) + + +class TestEventAnd(unittest.TestCase): + + @classmethod + def setUpClass(cls): + + M = GameMaker() + + # The goal + commands = ["take lime juice", "insert lime juice into chest", "take carrot"] + + R1 = M.new_room("room") + M.set_player(R1) + + lime = M.new(type='f', name='lime juice') + R1.add(lime) + + carrot = M.new(type='f', name='carrot') + R1.add(carrot) + + # Add a closed chest in R2. + chest = M.new(type='c', name='chest') + chest.add_property("open") + R1.add(chest) + + cls.first_event = M.new_event_using_commands(commands[:-1], event_style='condition') + cls.first_event_actions = cls.first_event.actions + cls.first_event_traceable = cls.first_event.traceable + cls.first_event_conditions = {M.new_fact("in", lime, chest)} + + cls.second_event = M.new_event_using_commands([commands[-1]], event_style='action') + cls.second_event_actions = cls.second_event.actions + cls.second_event_traceable = cls.second_event.traceable + + cls.event = EventAnd(events=(cls.first_event, cls.second_event)) + cls.events = cls.event.events + + def test_init(self): + first_event = EventCondition(actions=self.first_event_actions) + second_event = EventAction(actions=self.second_event_actions) + event = EventAnd(events=(first_event, second_event)) + + assert event.events[0].actions == self.first_event_actions + assert event.events[0].condition == self.first_event.condition + assert event.events[0].traceable == self.first_event_traceable + assert event.events[0].condition.preconditions == self.first_event_actions[-1].postconditions + assert set(event.events[0].condition.preconditions).issuperset(self.first_event_conditions) + assert event.events[1].actions == self.second_event_actions + assert event.events[1].traceable == self.second_event_traceable + + def test_serialization(self): + data = self.event.serialize() + event = EventAnd.deserialize(data) assert event == self.event def test_copy(self): @@ -176,7 +342,7 @@ def setUpClass(cls): M = GameMaker() # The goal - commands = ["go east", "insert carrot into chest"] + commands = ["open wooden door", "go east", "insert carrot into chest"] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -184,8 +350,8 @@ def setUpClass(cls): M.set_player(R1) path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("open") + door_a = M.new_door(path, name="wooden door") + M.add_fact("closed", door_a) carrot = M.new(type='f', name='carrot') M.inventory.add(carrot) @@ -195,16 +361,14 @@ def setUpClass(cls): chest.add_property("open") R2.add(chest) - cls.eventA = M.new_event_using_commands(commands) - cls.eventB = Event(conditions={M.new_fact("at", carrot, R1), - M.new_fact("closed", path.door)}) - cls.eventC = Event(conditions={M.new_fact("eaten", carrot)}) - cls.eventD = Event(conditions={M.new_fact("closed", chest), - M.new_fact("closed", path.door)}) - cls.quest = Quest(win_events=[cls.eventA, cls.eventB], - fail_events=[cls.eventC, cls.eventD], - reward=2) - + cls.eventA = M.new_event_using_commands(commands, event_style='condition') + cls.eventB = M.new_event(condition={M.new_fact("at", carrot, R1), M.new_fact("closed", path.door)}, + event_style='condition') + cls.eventC = M.new_event(condition={M.new_fact("eaten", carrot)}, event_style='condition') + cls.eventD = M.new_event(condition={M.new_fact("closed", chest), M.new_fact("closed", path.door)}, + event_style='condition') + cls.quest = M.new_quest(win_event={'or': (cls.eventA, cls.eventB)}, + fail_event={'or': (cls.eventC, cls.eventD)}, reward=2) M.quests = [cls.quest] cls.game = M.build() cls.inform7 = Inform7Game(cls.game) @@ -212,14 +376,14 @@ def setUpClass(cls): def test_init(self): npt.assert_raises(UnderspecifiedQuestError, Quest) - quest = Quest(win_events=[self.eventA, self.eventB]) + quest = Quest(win_events=new_operation(operation={'and': (self.eventA, self.eventB)})) assert len(quest.fail_events) == 0 - quest = Quest(fail_events=[self.eventC, self.eventD]) + quest = Quest(fail_events=new_operation(operation={'or': (self.eventC, self.eventD)})) assert len(quest.win_events) == 0 - quest = Quest(win_events=[self.eventA], - fail_events=[self.eventC, self.eventD]) + quest = Quest(win_events=new_operation(operation={'and': (self.eventA, self.eventB)}), + fail_events=new_operation(operation={'or': (self.eventC, self.eventD)})) assert len(quest.win_events) > 0 assert len(quest.fail_events) > 0 @@ -270,7 +434,8 @@ def _rule_to_skip(rule): # Build the quest by providing the actions. actions = chain.actions assert len(actions) == max_depth, rule.name - quest = Quest(win_events=[Event(actions)]) + + quest = Quest(win_events=new_operation(operation={'and': (EventCondition(actions=actions))})) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -281,7 +446,7 @@ def _rule_to_skip(rule): assert quest.is_winning(state) # Build the quest by only providing the winning conditions. - quest = Quest(win_events=[Event(conditions=actions[-1].postconditions)]) + quest = Quest(win_events=new_operation(operation={'and': (EventCondition(conditions=actions[-1].postconditions))})) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -293,7 +458,7 @@ def _rule_to_skip(rule): def test_win_actions(self): state = self.game.world.state.copy() - for action in self.quest.win_events[0].actions: + for action in self.quest.win_events_list[0].actions: assert not self.quest.is_winning(state) state.apply(action) @@ -306,20 +471,20 @@ def test_win_actions(self): self.game.kb.types.constants_mapping)) drop_carrot = _find_action("drop carrot", actions, self.inform7) - close_door = _find_action("close wooden door", actions, self.inform7) + open_door = _find_action("open wooden door", actions, self.inform7) state = self.game.world.state.copy() assert state.apply(drop_carrot) - assert not self.quest.is_winning(state) - assert state.apply(close_door) assert self.quest.is_winning(state) + assert state.apply(open_door) + assert not self.quest.is_winning(state) # Or the other way around. state = self.game.world.state.copy() - assert state.apply(close_door) + assert state.apply(open_door) assert not self.quest.is_winning(state) assert state.apply(drop_carrot) - assert self.quest.is_winning(state) + assert not self.quest.is_winning(state) def test_fail_actions(self): state = self.game.world.state.copy() @@ -328,7 +493,10 @@ def test_fail_actions(self): actions = list(state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) eat_carrot = _find_action("eat carrot", actions, self.inform7) - go_east = _find_action("go east", actions, self.inform7) + open_door = _find_action("open wooden door", actions, self.inform7) + state.apply(open_door) + actions = list(state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) for action in actions: state = self.game.world.state.copy() @@ -337,6 +505,13 @@ def test_fail_actions(self): assert self.quest.is_failing(state) == (action == eat_carrot) state = self.game.world.state.copy() + actions = list(state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) + open_door = _find_action("open wooden door", actions, self.inform7) + state.apply(open_door) + actions = list(state.all_applicable_actions(self.game.kb.rules.values(), + self.game.kb.types.constants_mapping)) + go_east = _find_action("go east", actions, self.inform7) state.apply(go_east) # Move to the kitchen. actions = list(state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) @@ -369,7 +544,7 @@ def setUpClass(cls): M = GameMaker() # The goal - commands = ["go east", "insert carrot into chest"] + commands = ["open wooden door", "go east", "insert carrot into chest"] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -377,8 +552,8 @@ def setUpClass(cls): M.set_player(R1) path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("open") + door_a = M.new_door(path, name="wooden door") + M.add_fact("closed", door_a) carrot = M.new(type='f', name='carrot') M.inventory.add(carrot) @@ -388,7 +563,7 @@ def setUpClass(cls): chest.add_property("open") R2.add(chest) - M.set_quest_from_commands(commands) + M.set_quest_from_commands(commands, event_style='condition') cls.game = M.build() def test_directions_names(self): @@ -458,18 +633,18 @@ def setUpClass(cls): chest.add_property("open") R1.add(chest) - cls.event = M.new_event_using_commands(commands) - cls.actions = cls.event.actions + cls.event = new_operation([M.new_event_using_commands(commands, event_style='condition')]) + cls.actions = cls.event[0].events[0].actions cls.conditions = {M.new_fact("in", carrot, chest)} cls.game = M.build() commands = ["take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) + cls.eating_carrot = new_operation([M.new_event_using_commands(commands, event_style='condition')]) def test_triggering_policy(self): - event = EventProgression(self.event, KnowledgeBase.default()) + event = EventProgression(self.event[0], KnowledgeBase.default()) state = self.game.world.state.copy() - expected_actions = self.event.actions + expected_actions = self.event[0].events[0].actions for i, action in enumerate(expected_actions): assert event.triggering_policy == expected_actions[i:] assert not event.done @@ -484,10 +659,10 @@ def test_triggering_policy(self): assert not event.untriggerable def test_untriggerable(self): - event = EventProgression(self.event, KnowledgeBase.default()) + event = EventProgression(self.event[0], KnowledgeBase.default()) state = self.game.world.state.copy() - for action in self.eating_carrot.actions: + for action in self.eating_carrot[0].events[0].actions: assert event.triggering_policy != () assert not event.done assert not event.triggered @@ -521,24 +696,23 @@ def setUpClass(cls): # The goals commands = ["take carrot", "insert carrot into chest"] - cls.eventA = M.new_event_using_commands(commands) + cls.eventA = M.new_event_using_commands(commands, event_style='condition') commands = ["take lettuce", "insert lettuce into chest", "close chest"] - event = M.new_event_using_commands(commands) - cls.eventB = Event(actions=event.actions, - conditions={M.new_fact("in", lettuce, chest), - M.new_fact("closed", chest)}) + event = M.new_event_using_commands(commands, event_style='condition') + cls.eventB = EventCondition(actions=event.actions, + conditions={M.new_fact("in", lettuce, chest), M.new_fact("closed", chest)}) - cls.fail_eventA = Event(conditions={M.new_fact("eaten", carrot)}) - cls.fail_eventB = Event(conditions={M.new_fact("eaten", lettuce)}) + cls.fail_eventA = EventCondition(conditions={M.new_fact("eaten", carrot)}) + cls.fail_eventB = EventCondition(conditions={M.new_fact("eaten", lettuce)}) - cls.quest = Quest(win_events=[cls.eventA, cls.eventB], - fail_events=[cls.fail_eventA, cls.fail_eventB]) + cls.quest = M.new_quest(win_event={'or': (cls.eventA, cls.eventB)}, + fail_event={'or': (cls.fail_eventA, cls.fail_eventB)}, reward=2) commands = ["take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) + cls.eating_carrot = M.new_event_using_commands(commands, event_style='condition') commands = ["take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands) + cls.eating_lettuce = M.new_event_using_commands(commands, event_style='condition') M.quests = [cls.quest] cls.game = M.build() @@ -594,10 +768,10 @@ def test_winning_policy(self): for i, action in enumerate(self.eventB.actions): if i < 2: assert quest.winning_policy == self.eventA.actions - else: - # After taking the lettuce and putting it in the chest, - # QuestB becomes the shortest one to complete. - assert quest.winning_policy == self.eventB.actions[i:] + # else: + # # After taking the lettuce and putting it in the chest, + # # QuestB becomes the shortest one to complete. + # assert quest.winning_policy == self.eventB.actions[i:] assert not quest.done state.apply(action) quest.update(action, state) @@ -620,8 +794,8 @@ def setUpClass(cls): M.set_player(R2) path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("closed") + door_a = M.new_door(path, name="wooden door") + M.add_fact("closed", door_a) carrot = M.new(type='f', name='carrot') lettuce = M.new(type='f', name='lettuce') @@ -640,30 +814,30 @@ def setUpClass(cls): # The goals commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] - cls.eventA = M.new_event_using_commands(commands) + cls.eventA = M.new_event_using_commands(commands, event_style='condition') commands = ["open wooden door", "go west", "take lettuce", "go east", "insert lettuce into chest"] - cls.eventB = M.new_event_using_commands(commands) + cls.eventB = M.new_event_using_commands(commands, event_style='condition') commands = ["drop pepper"] - cls.eventC = M.new_event_using_commands(commands) + cls.eventC = M.new_event_using_commands(commands, event_style='condition') - cls.losing_eventA = Event(conditions={M.new_fact("eaten", carrot)}) - cls.losing_eventB = Event(conditions={M.new_fact("eaten", lettuce)}) + cls.losing_eventA = EventCondition(conditions={M.new_fact("eaten", carrot)}) + cls.losing_eventB = EventCondition(conditions={M.new_fact("eaten", lettuce)}) - cls.questA = Quest(win_events=[cls.eventA], fail_events=[cls.losing_eventA]) - cls.questB = Quest(win_events=[cls.eventB], fail_events=[cls.losing_eventB]) - cls.questC = Quest(win_events=[cls.eventC], fail_events=[]) - cls.questD = Quest(win_events=[], fail_events=[cls.losing_eventA, cls.losing_eventB]) + cls.questA = M.new_quest(win_event=[cls.eventA], fail_event=[cls.losing_eventA]) + cls.questB = M.new_quest(win_event=[cls.eventB], fail_event=[cls.losing_eventB]) + cls.questC = M.new_quest(win_event=[cls.eventC], fail_event=[]) + cls.questD = M.new_quest(win_event=[], fail_event=[cls.losing_eventA, cls.losing_eventB]) commands = ["open wooden door", "go west", "take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands) + cls.eating_carrot = M.new_event_using_commands(commands, event_style='condition') commands = ["open wooden door", "go west", "take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands) + cls.eating_lettuce = M.new_event_using_commands(commands, event_style='condition') commands = ["eat tomato"] - cls.eating_tomato = M.new_event_using_commands(commands) + cls.eating_tomato = M.new_event_using_commands(commands, event_style='condition') commands = ["eat pepper"] - cls.eating_pepper = M.new_event_using_commands(commands) + cls.eating_pepper = M.new_event_using_commands(commands, event_style='condition') M.quests = [cls.questA, cls.questB, cls.questC] cls.game = M.build() @@ -770,8 +944,8 @@ def test_cycle_in_winning_policy(self): R4 = M.new_room("r4") M.set_player(R1) - M.connect(R0.south, R1.north), - M.connect(R1.east, R2.west), + M.connect(R0.south, R1.north) + M.connect(R1.east, R2.west) M.connect(R3.east, R4.west) M.connect(R1.south, R3.north) M.connect(R2.south, R4.north) @@ -783,7 +957,7 @@ def test_cycle_in_winning_policy(self): R2.add(apple) commands = ["go north", "take carrot"] - M.set_quest_from_commands(commands) + M.set_quest_from_commands(commands, event_style='condition') game = M.build() inform7 = Inform7Game(game) game_progression = GameProgression(game) @@ -807,7 +981,7 @@ def test_cycle_in_winning_policy(self): # Quest where player's has to pick up the carrot first. commands = ["go east", "take apple", "go west", "go north", "drop apple"] - M.set_quest_from_commands(commands) + M.set_quest_from_commands(commands, event_style='condition') game = M.build() game_progression = GameProgression(game) @@ -842,8 +1016,8 @@ def test_game_with_multiple_quests(self): M.set_player(R2) path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("closed") + door_a = M.new_door(path, name="wooden door") + M.add_fact("closed", door_a) carrot = M.new(type='f', name='carrot') lettuce = M.new(type='f', name='lettuce') @@ -854,15 +1028,15 @@ def test_game_with_multiple_quests(self): chest.add_property("open") R2.add(chest) - quest1 = M.new_quest_using_commands(commands[0]) + quest1 = M.new_quest_using_commands(commands[0], event_style='condition') quest1.desc = "Fetch the carrot and drop it on the kitchen's ground." - quest2 = M.new_quest_using_commands(commands[0] + commands[1]) + quest2 = M.new_quest_using_commands(commands[0] + commands[1], event_style='condition') quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground." - quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) + quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], event_style='condition') winning_facts = [M.new_fact("in", lettuce, chest), M.new_fact("in", carrot, chest), M.new_fact("closed", chest)] - quest3.win_events[0].set_conditions(winning_facts) + quest3.win_events[0].events[0].set_conditions(winning_facts) quest3.desc = "Put the lettuce and the carrot into the chest before closing it." M.quests = [quest1, quest2, quest3] diff --git a/textworld/generator/tests/test_maker.py b/textworld/generator/tests/test_maker.py index 6ffb6f70..3aebc523 100644 --- a/textworld/generator/tests/test_maker.py +++ b/textworld/generator/tests/test_maker.py @@ -113,8 +113,8 @@ def test_making_a_small_game(play_the_game=False): path = M.connect(R1.east, R2.west) # Undirected path # Add a closed door between R1 and R2. - door = M.new_door(path, name='glass door') - door.add_property("locked") + door = M.new_door(path, name="glass door") + M.add_fact("locked", door) # Put a matching key for the door on R1's floor. key = M.new(type='k', name='rusty key') @@ -148,7 +148,7 @@ def test_record_quest_from_commands(play_the_game=False): M = GameMaker() # The goal - commands = ["go east", "insert ball into chest"] + commands = ["open wooden door", "go east", "insert ball into chest"] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -156,8 +156,8 @@ def test_record_quest_from_commands(play_the_game=False): M.set_player(R1) path = M.connect(R1.east, R2.west) - path.door = M.new(type='d', name='wooden door') - path.door.add_property("open") + door_a = M.new_door(path, name="wooden door") + M.add_fact("closed", door_a) ball = M.new(type='o', name='ball') M.inventory.add(ball) @@ -167,7 +167,7 @@ def test_record_quest_from_commands(play_the_game=False): chest.add_property("open") R2.add(chest) - M.set_quest_from_commands(commands) + M.set_quest_from_commands(commands, event_style='condition') game = M.build() with make_temp_directory(prefix="test_record_quest_from_commands") as tmpdir: diff --git a/textworld/generator/tests/test_text_generation.py b/textworld/generator/tests/test_text_generation.py index fe4c72f7..290555ac 100644 --- a/textworld/generator/tests/test_text_generation.py +++ b/textworld/generator/tests/test_text_generation.py @@ -57,15 +57,14 @@ def test_blend_instructions(verbose=False): M.set_player(r1) path = M.connect(r1.north, r2.south) - path.door = M.new(type="d", name="door") - M.add_fact("locked", path.door) + door_a = M.new_door(path, name="wooden door") + M.add_fact("locked", door_a) key = M.new(type="k", name="key") M.add_fact("match", key, path.door) r1.add(key) quest = M.set_quest_from_commands(["take key", "unlock door with key", "open door", "go north", - "close door", "lock door with key", "drop key"]) - + "close door", "lock door with key", "drop key"], event_style='condition') game = M.build() grammar1 = textworld.generator.make_grammar({"blend_instructions": False}, diff --git a/textworld/generator/vtypes.py b/textworld/generator/vtypes.py index 8ee08b1d..f9ceb9de 100644 --- a/textworld/generator/vtypes.py +++ b/textworld/generator/vtypes.py @@ -178,7 +178,10 @@ def load(cls, path: str): def __getitem__(self, vtype): """ Get VariableType object from its type string. """ vtype = vtype.rstrip("'") - return self.variables_types[vtype] + if vtype in self.variables_types.keys(): + return self.variables_types[vtype] + else: + return None def __contains__(self, vtype): vtype = vtype.rstrip("'") @@ -199,9 +202,10 @@ def descendants(self, vtype): return [] descendants = [] - for child_type in self[vtype].children: - descendants.append(child_type) - descendants += self.descendants(child_type) + if self[vtype]: + for child_type in self[vtype].children: + descendants.append(child_type) + descendants += self.descendants(child_type) return descendants diff --git a/textworld/logic/__init__.py b/textworld/logic/__init__.py index 32dcc87c..ba1e5019 100644 --- a/textworld/logic/__init__.py +++ b/textworld/logic/__init__.py @@ -2053,3 +2053,7 @@ def has_traceable(self): return True return False + @property + def logic(self): + return self._logic + From ec5185b4784730664c5814f78b9d125997c92268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 20 May 2020 20:28:16 -0400 Subject: [PATCH 5/5] Refactor new Event class. Makes it work with DependencyTree. Add tests. --- .gitignore | 2 + .../challenges/tests/test_coin_collector.py | 2 +- .../challenges/tests/test_treasure_hunter.py | 2 +- .../tw_coin_collector/coin_collector.py | 21 +- .../textworld_data/logic/player.twl | 4 + textworld/challenges/tw_cooking/cooking.py | 2 +- .../tw_simple/textworld_data/logic/key.twl | 44 +- .../tw_simple/textworld_data/logic/room.twl | 17 +- .../tw_treasure_hunter/treasure_hunter.py | 2 +- textworld/generator/__init__.py | 15 +- .../data/text_grammars/house_quests.twg | 56 - textworld/generator/game.py | 992 ++++++++---------- .../inform7/tests/test_world2inform7.py | 113 +- textworld/generator/inform7/world2inform7.py | 242 +---- textworld/generator/logger.py | 4 +- textworld/generator/maker.py | 174 ++- textworld/generator/tests/test_game.py | 786 ++++++-------- textworld/generator/tests/test_maker.py | 12 +- .../generator/tests/test_text_generation.py | 7 +- textworld/generator/text_generation.py | 318 +----- textworld/generator/vtypes.py | 12 +- textworld/generator/world.py | 4 +- textworld/logic/__init__.py | 154 +-- textworld/logic/parser.py | 31 +- textworld/testing.py | 104 +- 25 files changed, 1204 insertions(+), 1916 deletions(-) delete mode 100644 textworld/generator/data/text_grammars/house_quests.twg diff --git a/.gitignore b/.gitignore index 6632a7cd..faa33005 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ tmp/* *.ipynb_checkpoints /dist /wheelhouse +docs/build +docs/src *.orig diff --git a/textworld/challenges/tests/test_coin_collector.py b/textworld/challenges/tests/test_coin_collector.py index f705a4cd..facab8b1 100644 --- a/textworld/challenges/tests/test_coin_collector.py +++ b/textworld/challenges/tests/test_coin_collector.py @@ -20,5 +20,5 @@ def test_making_coin_collector(): settings = {"level": level} game = coin_collector.make(settings, options) - assert len(game.quests[0].commands) == expected[level]["quest_length"] + assert len(game.walkthrough) == expected[level]["quest_length"] assert len(game.world.rooms) == expected[level]["nb_rooms"] diff --git a/textworld/challenges/tests/test_treasure_hunter.py b/textworld/challenges/tests/test_treasure_hunter.py index 65dbd9fd..922a70cf 100644 --- a/textworld/challenges/tests/test_treasure_hunter.py +++ b/textworld/challenges/tests/test_treasure_hunter.py @@ -13,5 +13,5 @@ def test_making_treasure_hunter_games(): settings = {"level": level} game = treasure_hunter.make(settings, options) - assert len(game.quests[0].commands) == game.metadata["quest_length"], "Level {}".format(level) + assert len(game.walkthrough) == game.metadata["quest_length"], "Level {}".format(level) assert len(game.world.rooms) == game.metadata["world_size"], "Level {}".format(level) diff --git a/textworld/challenges/tw_coin_collector/coin_collector.py b/textworld/challenges/tw_coin_collector/coin_collector.py index 243f66ce..d916d27b 100644 --- a/textworld/challenges/tw_coin_collector/coin_collector.py +++ b/textworld/challenges/tw_coin_collector/coin_collector.py @@ -15,17 +15,23 @@ other than the coin to collect. """ +import os import argparse +from os.path import join as pjoin from typing import Mapping, Optional, Any import textworld from textworld.generator.graph_networks import reverse_direction from textworld.utils import encode_seeds -from textworld.generator.game import GameOptions, Quest, EventCondition +from textworld.generator.data import KnowledgeBase +from textworld.generator.game import GameOptions, Quest, Event from textworld.challenges import register +KB_PATH = pjoin(os.path.dirname(__file__), "textworld_data") + + def build_argparser(parser=None): parser = parser or argparse.ArgumentParser() @@ -33,9 +39,6 @@ def build_argparser(parser=None): group.add_argument("--level", required=True, type=int, help="The difficulty level. Must be between 1 and 300 (included).") - group.add_argument("--force-entity-numbering", required=True, action="store_true", - help="This will set `--entity-numbering` to be True which is required for this challenge.") - return parser @@ -49,7 +52,7 @@ def make(settings: Mapping[str, Any], options: Optional[GameOptions] = None) -> :py:class:`textworld.GameOptions ` for the list of available options). - .. warning:: This challenge requires `options.grammar.allowed_variables_numbering` to be `True`. + .. warning:: This challenge enforces `options.grammar.allowed_variables_numbering` to be `True`. Returns: Generated game. @@ -68,9 +71,11 @@ def make(settings: Mapping[str, Any], options: Optional[GameOptions] = None) -> """ options = options or GameOptions() + # Load knowledge base specific to this challenge. + options.kb = KnowledgeBase.load(KB_PATH) + # Needed for games with a lot of rooms. - options.grammar.allowed_variables_numbering = settings["force_entity_numbering"] - assert options.grammar.allowed_variables_numbering + options.grammar.allowed_variables_numbering = True level = settings["level"] if level < 1 or level > 300: @@ -167,7 +172,7 @@ def make_game(mode: str, options: GameOptions) -> textworld.Game: # Generate the quest thats by collecting the coin. quest = Quest(win_events=[ - EventCondition(conditions={M.new_fact("in", coin, M.inventory)}) + Event(conditions={M.new_fact("in", coin, M.inventory)}) ]) M.quests = [quest] diff --git a/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl b/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl index 6783223b..450e47bd 100644 --- a/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl +++ b/textworld/challenges/tw_coin_collector/textworld_data/logic/player.twl @@ -4,6 +4,10 @@ type P { look :: at(P, r) -> at(P, r); # Nothing changes. } + reverse_rules { + look :: look; + } + inform7 { commands { look :: "look" :: "looking"; diff --git a/textworld/challenges/tw_cooking/cooking.py b/textworld/challenges/tw_cooking/cooking.py index 4513b1b1..c7b312fd 100644 --- a/textworld/challenges/tw_cooking/cooking.py +++ b/textworld/challenges/tw_cooking/cooking.py @@ -928,7 +928,7 @@ def make(settings: Mapping[str, str], options: Optional[GameOptions] = None) -> start_room = rng_map.choice(M.rooms) M.set_player(start_room) - M.grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar) + M.grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar, kb=options.kb) # Remove every food preparation with grilled, if there is no BBQ. if M.find_by_name("BBQ") is None: diff --git a/textworld/challenges/tw_simple/textworld_data/logic/key.twl b/textworld/challenges/tw_simple/textworld_data/logic/key.twl index c7da05e5..ff6d0499 100644 --- a/textworld/challenges/tw_simple/textworld_data/logic/key.twl +++ b/textworld/challenges/tw_simple/textworld_data/logic/key.twl @@ -1,25 +1,25 @@ -# # key -# type k : o { -# predicates { -# match(k, c); -# match(k, d); -# } +# key +type k : o { + predicates { + match(k, c); + match(k, d); + } -# constraints { -# k1 :: match(k, c) & match(k', c) -> fail(); -# k2 :: match(k, c) & match(k, c') -> fail(); -# k3 :: match(k, d) & match(k', d) -> fail(); -# k4 :: match(k, d) & match(k, d') -> fail(); -# } + constraints { + k1 :: match(k, c) & match(k', c) -> fail(); + k2 :: match(k, c) & match(k, c') -> fail(); + k3 :: match(k, d) & match(k', d) -> fail(); + k4 :: match(k, d) & match(k, d') -> fail(); + } -# inform7 { -# type { -# kind :: "key"; -# } + inform7 { + type { + kind :: "key"; + } -# predicates { -# match(k, c) :: "The matching key of the {c} is the {k}"; -# match(k, d) :: "The matching key of the {d} is the {k}"; -# } -# } -# } + predicates { + match(k, c) :: "The matching key of the {c} is the {k}"; + match(k, d) :: "The matching key of the {d} is the {k}"; + } + } +} diff --git a/textworld/challenges/tw_simple/textworld_data/logic/room.twl b/textworld/challenges/tw_simple/textworld_data/logic/room.twl index 62bde7f0..58715688 100644 --- a/textworld/challenges/tw_simple/textworld_data/logic/room.twl +++ b/textworld/challenges/tw_simple/textworld_data/logic/room.twl @@ -7,15 +7,16 @@ type r { north_of(r, r); west_of(r, r); + north_of/d(r, d, r); + west_of/d(r, d, r); + free(r, r); south_of(r, r') = north_of(r', r); east_of(r, r') = west_of(r', r); - # north_of/d(r, d, r); - # west_of/d(r, d, r); - # south_of/d(r, d, r') = north_of/d(r', d, r); - # east_of/d(r, d, r') = west_of/d(r', d, r); + south_of/d(r, d, r') = north_of/d(r', d, r); + east_of/d(r, d, r') = west_of/d(r', d, r); } rules { @@ -65,10 +66,10 @@ type r { east_of(r, r') :: "The {r} is mapped east of {r'}"; west_of(r, r') :: "The {r} is mapped west of {r'}"; - # north_of/d(r, d, r') :: "South of {r} and north of {r'} is a door called {d}"; - # south_of/d(r, d, r') :: "North of {r} and south of {r'} is a door called {d}"; - # east_of/d(r, d, r') :: "West of {r} and east of {r'} is a door called {d}"; - # west_of/d(r, d, r') :: "East of {r} and west of {r'} is a door called {d}"; + north_of/d(r, d, r') :: "South of {r} and north of {r'} is a door called {d}"; + south_of/d(r, d, r') :: "North of {r} and south of {r'} is a door called {d}"; + east_of/d(r, d, r') :: "West of {r} and east of {r'} is a door called {d}"; + west_of/d(r, d, r') :: "East of {r} and west of {r'} is a door called {d}"; } commands { diff --git a/textworld/challenges/tw_treasure_hunter/treasure_hunter.py b/textworld/challenges/tw_treasure_hunter/treasure_hunter.py index fabea384..45a8d57d 100644 --- a/textworld/challenges/tw_treasure_hunter/treasure_hunter.py +++ b/textworld/challenges/tw_treasure_hunter/treasure_hunter.py @@ -233,7 +233,7 @@ def make_game(mode: str, options: GameOptions) -> textworld.Game: quest = Quest(win_events=[event], fail_events=[Event(conditions={Proposition("in", [wrong_obj, world.inventory])})]) - grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar) + grammar = textworld.generator.make_grammar(options.grammar, rng=rng_grammar, kb=options.kb) game = textworld.generator.make_game_with(world, [quest], grammar) game.metadata.update(metadata) mode_choice = modes.index(mode) diff --git a/textworld/generator/__init__.py b/textworld/generator/__init__.py index 87f5de50..6dd70764 100644 --- a/textworld/generator/__init__.py +++ b/textworld/generator/__init__.py @@ -17,7 +17,9 @@ from textworld.generator.chaining import ChainingOptions, QuestGenerationError from textworld.generator.chaining import sample_quest from textworld.generator.world import World -from textworld.generator.game import Game, Quest, Event, GameOptions +from textworld.generator.game import Game, Quest, GameOptions +from textworld.generator.game import EventCondition, EventAction, EventAnd, EventOr +from textworld.generator.game import Event # For backward compatibility from textworld.generator.graph_networks import create_map, create_small_map from textworld.generator.text_generation import generate_text_from_grammar @@ -142,19 +144,18 @@ def make_quest(world: Union[World, State], options: Optional[GameOptions] = None for i in range(1, len(chain.nodes)): actions.append(chain.actions[i - 1]) if chain.nodes[i].breadth != chain.nodes[i - 1].breadth: - event = Event(actions) - quests.append(Quest(win_events=[event])) + quests.append(Quest(win_event=EventCondition(actions=actions))) actions.append(chain.actions[-1]) - event = Event(actions) - quests.append(Quest(win_events=[event])) + quests.append(Quest(win_event=EventCondition(actions=actions))) return quests -def make_grammar(options: Mapping = {}, rng: Optional[RandomState] = None) -> Grammar: +def make_grammar(options: Mapping = {}, rng: Optional[RandomState] = None, + kb: Optional[KnowledgeBase] = None) -> Grammar: rng = g_rng.next() if rng is None else rng - grammar = Grammar(options, rng) + grammar = Grammar(options, rng, kb) grammar.check() return grammar diff --git a/textworld/generator/data/text_grammars/house_quests.twg b/textworld/generator/data/text_grammars/house_quests.twg deleted file mode 100644 index b220d1c4..00000000 --- a/textworld/generator/data/text_grammars/house_quests.twg +++ /dev/null @@ -1,56 +0,0 @@ -#------------------------- -#Quests Grammar -#------------------------- - -punishing_quest_none:#punishing_prologue_none# -punishing_prologue_none: in this quest, there are some activities that if you do them, you will be punished. You will find out what are those.;your activities matter in this quest, some of them will have punishment for you. Explore the environment.;your activities changes the state of the game. There is a state which is dangerous for you, you\'ll be charged if you get there. - -punishing_quest_one_task:#punishing_prologue# (combined_task) - -punishing_quest_tasks:#punishing_prologue# #AndOr# (list_of_combined_tasks) -AndOr:complete any of the following tasks;do none of the following tasks;do any single of the following tasks; commit any of the following tasks - -punishing_prologue: your mission for this quest is not to; your task for this quest is not to; there is something I need you to be careful about it. Please never; your objective is not to; please hesitate to - -quest_none:#prologue_none# -prologue_none: there are some activities I need you to do for me in this quest. You will find out what are those.;your activities matter in this quest. Explore the environment.;your activities changes the state of the game. There is a state which is important to this quest. - -quest_one_task:#prologue# (combined_task) - -quest_and_tasks:#prologue# #And# (list_of_combined_tasks) -And:complete all the following tasks;do every single of the following tasks; finish every single of the following tasks - -quest_or_tasks:#prologue# #Or# (list_of_combined_tasks) -Or:complete any of the following tasks;do at least one of the following tasks; finish one or more of the following tasks - -prologue: your mission for this quest is to; your task for this quest is to; there is something I need you to do for me. Please ; your objective is to; please - -combined_one_task:#prologue_combined_one_task# (only_task) -prologue_combined_one_task:there is only one event to do in this task;the only objective event here is; - -combined_and_tasks:#prologue_combined_and_tasks# (list_of_tasks) -prologue_combined_and_tasks:make sure to complete all these events in this task;complete all the events for this task;do all these events; all these events need to be done in this task - -combined_or_tasks:#prologue_combined_or_tasks# (list_of_tasks) -prologue_combined_or_tasks:make sure to complete any of these events in this task;complete at least one of the following events in this in this task;do minimum one of these events here; at least one of these events need to be done for this task - -combined_one_event:#prologue_combined_one_event# (only_event) -prologue_combined_one_event:The event includes the following event;the objective of this event is given as follows - -combined_and_events:#prologue_combined_and_events# (list_of_events) -prologue_combined_and_events:complete all these actions;do all these actions; all these actions need to be done - -combined_or_events:#prologue_combined_or_events# (list_of_events) -prologue_combined_or_events:complete any of these actions;doing any of these actions is sufficient;Completing any from the following actions is okay - -event:(list_of_actions) - - -all_quests_non: #prologue_quests# #epilogue# - -all_quests: #prologue_quests# (quests_string) #epilogue# - -prologue_quests:#welcome#;Hey, thanks for coming over to the TextWorld today, there is something I need you to do for me.;Hey, thanks for coming over to TextWorld! Your activity matters here. -epilogue:Once that's handled, you can stop!;You will know when you can stop!;And once you've done, you win!;You're the winner when all activities are taken!;Got that? Great!;Alright, thanks! -welcome:Welcome to TextWorld;You are now playing a #exciting# #game# of TextWorld Spaceship;Welcome to another #exciting# #game# of TextWorld;It's time to explore the amazing world of TextWorld Galaxy;Get ready to pick stuff up and put it in places, because you've just entered TextWorld shuttle;I hope you're ready to go into rooms and interact with objects, because you've just entered TextWorld shuttle;Who's got a virtual machine and is about to play through an #exciting# round of TextWorld? You do; -game:game;round;session;episode diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 4b731bc9..cc97f74a 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -5,21 +5,20 @@ import copy import json import textwrap -import re +import warnings +import itertools from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple from collections import OrderedDict -from functools import partial from numpy.random import RandomState -import textworld from textworld import g_rng from textworld.utils import encode_seeds from textworld.generator.data import KnowledgeBase from textworld.generator.text_grammar import Grammar, GrammarOptions from textworld.generator.world import World -from textworld.logic import Action, Proposition, State +from textworld.logic import Action, Proposition, State, Rule, Variable from textworld.generator.graph_networks import DIRECTIONS from textworld.generator.chaining import ChainingOptions @@ -41,107 +40,33 @@ def __init__(self): super().__init__(msg) -class UnderspecifiedEventActionError(NameError): - def __init__(self): - msg = "The EventAction includes ONLY one action." - super().__init__(msg) - - class UnderspecifiedQuestError(NameError): def __init__(self): msg = "At least one winning or failing event is needed to create a quest." super().__init__(msg) -def gen_commands_from_actions(actions: Iterable[Action], kb: Optional[KnowledgeBase] = None) -> List[str]: - kb = kb or KnowledgeBase.default() - - def _get_name_mapping(action): - mapping = kb.rules[action.name].match(action) - return {ph.name: var.name for ph, var in mapping.items()} - - commands = [] - for action in actions: - command = "None" - if action is not None: - command = kb.inform7_commands[action.name] - command = command.format(**_get_name_mapping(action)) - - commands.append(command) - - return commands - - -class PropositionControl: - """ - Controlling the proposition's appearance within the game. - - When a proposition is activated in the state set, it may be important to track this event. This basically is - determined in the quest design directly or indirectly. This class manages the creation of the event propositions, - Add or Remove the event proposition from the state set, etc. - - Attributes: - - """ - - def __init__(self, props: Iterable[Proposition], verbs: dict): - - self.propositions = props - self.verbs = verbs - self.traceable_propositions, self.addon = self.set_events() - - def set_events(self): - variables = sorted(set([v for c in self.propositions for v in c.arguments])) - event = Proposition("event", arguments=variables) - - if self.verbs: - state_event = [Proposition(name=self.verbs[prop.definition].replace(' ', '_') + '__' + prop.definition, - arguments=prop.arguments) - for prop in self.propositions if prop.definition in self.verbs.keys()] - else: - state_event = [] - - return state_event, event +class TextworldGameVersionWarning(UserWarning): + pass - @classmethod - def remove(cls, prop: Proposition, state: State): - if not prop.name.startswith('was__'): - return - - if prop in state.facts: - if Proposition(prop.definition, prop.arguments) not in state.facts: - state.remove_fact(prop) - - def has_traceable(self): - for prop in self.get_facts(): - if not prop.name.startswith('is__'): - return True - return False +class AbstractEvent: -class Event: + _SERIAL_VERSION = 2 - def __init__(self, actions: Iterable[Action] = (), commands: Iterable[str] = ()) -> None: + def __init__(self, actions: Iterable[Action] = (), commands: Iterable[str] = (), name: str = "") -> None: """ Args: actions: The actions to be performed to trigger this event. commands: Human readable version of the actions. """ - - self.actions = list(actions) - + self.actions = actions self.commands = commands + self.name = name + self.is_dnf = False @property - def verb_tense(self) -> dict: - return self._verb_tense - - @verb_tense.setter - def verb_tense(self, verb: dict) -> None: - self._verb_tense = verb - - @property - def actions(self) -> Tuple[Action]: + def actions(self) -> Iterable[Action]: return self._actions @actions.setter @@ -160,53 +85,108 @@ def __hash__(self) -> int: return hash((self.actions, self.commands)) def __eq__(self, other: Any) -> bool: - return (isinstance(other, Event) and - self.actions == other.actions and - self.commands == other.commands) + return (isinstance(other, AbstractEvent) + and self.actions == other.actions + and self.commands == other.commands + and self.name == other.name + and self.is_dnf == other.is_dnf) @classmethod - def deserialize(cls, data: Mapping) -> "Event": - """ Creates an `Event` from serialized data. + def deserialize(cls, data: Mapping) -> Union["AbstractEvent", "EventCondition", "EventAction", "EventOr", "EventAnd"]: + """ Creates a `AbstractEvent` (or one of its subtypes) from serialized data. Args: - data: Serialized data with the needed information to build a `Event` object. + data: Serialized data with the needed information to build a `AbstractEvent` (or one of its subtypes) object. """ - actions = [Action.deserialize(d) for d in data["actions_Event"]] - return cls(actions, data["commands_Event"]) + version = data.get("version", 1) + if version == 1: + data["type"] = "EventCondition" + + if data["type"] == "EventCondition": + obj = EventCondition.deserialize(data) + elif data["type"] == "EventAction": + obj = EventAction.deserialize(data) + elif data["type"] == "EventOr": + obj = EventOr.deserialize(data) + elif data["type"] == "EventAnd": + obj = EventAnd.deserialize(data) + elif data["type"] == "AbstractEvent": + obj = cls() + + obj.actions = [Action.deserialize(d) for d in data["actions"]] + obj.commands = data["commands"] + obj.name = data.get("name", "") + obj.is_dnf = data.get("is_dnf", False) + return obj def serialize(self) -> Mapping: """ Serialize this event. Results: - `Event`'s data serialized to be JSON compatible. + Event's data serialized to be JSON compatible """ - return {"commands_Event": self.commands, - "actions_Event": [action.serialize() for action in self.actions]} + return { + "version": self._SERIAL_VERSION, + "type": self.__class__.__name__, + "commands": self.commands, + "actions": [action.serialize() for action in self.actions], + "name": self.name, + "is_dnf": self.is_dnf, + } - def copy(self) -> "Event": + def copy(self) -> "AbstractEvent": """ Copy this event. """ return self.deserialize(self.serialize()) + @classmethod + def to_dnf(cls, expr: Optional["AbstractEvent"]) -> Optional["AbstractEvent"]: + """Normalize a boolean expression to its DNF. + + Expr can be an AbstractEvent, it this case it returns EventOr([EventAnd([element])]). + Expr can be an EventOr(...) / EventAnd(...) expressions, + in which cases it returns also a disjunctive normalised form (removing identical elements) + + References: + Code inspired by https://stackoverflow.com/a/58372345 + """ + if expr is None: + return None + + if expr.is_dnf: + return expr # Expression is already in DNF. + + if not isinstance(expr, (EventOr, EventAnd)): + result = EventOr((EventAnd((expr,)),)) + + elif isinstance(expr, EventOr): + result = EventOr(se for e in expr for se in cls.to_dnf(e)) + + elif isinstance(expr, EventAnd): + total = [] + for c in itertools.product(*[cls.to_dnf(e) for e in expr]): + total.append(EventAnd(se for e in c for se in e)) + + result = EventOr(total) + + result.is_dnf = True + return result + + +class EventCondition(AbstractEvent): -class EventCondition(Event): def __init__(self, conditions: Iterable[Proposition] = (), - verb_tense: dict = (), actions: Iterable[Action] = (), commands: Iterable[str] = (), - ) -> None: + **kwargs) -> None: """ Args: - actions: The actions to be performed to trigger this event. - If an empty list, then `conditions` must be provided. conditions: Set of propositions which need to be all true in order for this event to get triggered. + actions: The actions to be performed to trigger this event. + If an empty list, then `conditions` must be provided. commands: Human readable version of the actions. - verb_tense: The desired verb tense for any state propositions which are been tracking. """ - super(EventCondition, self).__init__(actions, commands) - - self.verb_tense = verb_tense - + super(EventCondition, self).__init__(actions, commands, **kwargs) self.condition = self.set_conditions(conditions) def set_conditions(self, conditions: Iterable[Proposition]) -> Action: @@ -228,49 +208,43 @@ def set_conditions(self, conditions: Iterable[Proposition]) -> Action: # last action in the quest. conditions = self.actions[-1].postconditions - event = PropositionControl(conditions, self.verb_tense) - self.traceable = event.traceable_propositions - condition = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event.addon]) - - return condition - - def is_valid(self): - return isinstance(self.condition, Action) + variables = sorted(set([v for c in conditions for v in c.arguments])) + event = Proposition("event", arguments=variables) + self.condition = Action("trigger", preconditions=conditions, + postconditions=list(conditions) + [event]) + return self.condition - def is_triggering(self, state: State, actions: Iterable[Action] = ()) -> bool: + def is_triggering(self, state: State, action: Optional[Action] = None, callback: Optional[callable] = None) -> bool: """ Check if this event would be triggered in a given state. """ + is_triggering = state.is_applicable(self.condition) + if callback and is_triggering: + callback(self) - return state.is_applicable(self.condition) + return is_triggering - @property - def traceable(self) -> Iterable[Proposition]: - return self._traceable + def __str__(self) -> str: + return str(self.condition) - @traceable.setter - def traceable(self, traceable: Iterable[Proposition]) -> None: - self._traceable = tuple(traceable) + def __repr__(self) -> str: + return "EventCondition(Action.parse('{}'), name={})".format(self.condition, self.name) def __hash__(self) -> int: - return hash((self.actions, self.commands, self.condition, self.verb_tense, self.traceable)) + return hash((self.actions, self.commands, self.condition)) def __eq__(self, other: Any) -> bool: - return (isinstance(other, EventCondition) and - self.actions == other.actions and - self.commands == other.commands and - self.condition == other.condition and - self.verb_tense == other.verb_tense and - self.traceable == other.traceable) + return (isinstance(other, EventCondition) + and super().__eq__(other) + and self.condition == other.condition) @classmethod def deserialize(cls, data: Mapping) -> "EventCondition": """ Creates an `EventCondition` from serialized data. Args: - data: Serialized data with the needed information to build a `EventCondotion` object. + data: Serialized data with the needed information to build a `EventCondition` object. """ - actions = [Action.deserialize(d) for d in data["actions_EventCondition"]] - condition = Action.deserialize(data["condition_EventCondition"]) - return cls(condition.preconditions, data["verb_tense_EventCondition"], actions, data["commands_EventCondition"]) + condition = Action.deserialize(data["condition"]) + return cls(conditions=condition.preconditions) def serialize(self) -> Mapping: """ Serialize this event. @@ -278,76 +252,77 @@ def serialize(self) -> Mapping: Results: `EventCondition`'s data serialized to be JSON compatible. """ - return {"commands_EventCondition": self.commands, - "actions_EventCondition": [action.serialize() for action in self.actions], - "condition_EventCondition": self.condition.serialize(), - "verb_tense_EventCondition": self.verb_tense} - - def copy(self) -> "EventCondition": - """ Copy this event. """ - return self.deserialize(self.serialize()) - + data = super().serialize() + data["condition"] = self.condition.serialize() + return data -class EventAction(Event): - def __init__(self, actions: Iterable[Action] = (), - verb_tense: dict = (), - commands: Iterable[str] = ()) -> None: - """ - Args: - actions: The actions to be performed to trigger this event. - commands: Human readable version of the actions. - verb_tense: The desired verb tense for any state propositions which are been tracking. - """ - super(EventAction, self).__init__(actions, commands) +class Event: # For backward compatibility. + """ + Event happening in TextWorld. - if self.is_valid(): - raise UnderspecifiedEventActionError + An event gets triggered when its set of conditions become all statisfied. - self.verb_tense = verb_tense + .. warning:: Deprecated in favor of + :py:class:`textworld.generator.EventCondition `. + """ - self.traceable = self.set_actions() + def __new__(cls, actions: Iterable[Action] = (), + conditions: Iterable[Proposition] = (), + commands: Iterable[str] = ()): + return EventCondition(actions=actions, conditions=conditions, commands=commands) - def set_actions(self): - traceable = [] - for act in self.actions: - props = [] - for p in act.all_propositions: - if p not in props: - props.append(p) - event = PropositionControl(props, self.verb_tense) - traceable.append(event.traceable_propositions) +class EventAction(AbstractEvent): - return [prop for ar in traceable for prop in ar] + def __init__(self, action: Rule, + actions: Iterable[Action] = (), + commands: Iterable[str] = (), + **kwargs) -> None: + """ + Args: + action: The action to be performed to trigger this event. + actions: The actions to be performed to trigger this event. + commands: Human readable version of the actions. - def is_valid(self): - return len(self.actions) != 1 + Notes: + TODO: EventAction are temporal. + """ + super(EventAction, self).__init__(actions, commands, **kwargs) + self.action = action - def is_triggering(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: """ Check if this event would be triggered for a given action. """ - if not actions: + if action is None: return False - return all((actions[i] == self.actions[i] for i in range(len(actions)))) + mapping = self.action.match(action) + if mapping is None: + return False - @property - def traceable(self) -> Iterable[Proposition]: - return self._traceable + is_triggering = all( + ph.name == mapping[ph].name for ph in self.action.placeholders if ph.name != ph.type + ) + if callback and is_triggering: + callback(self) + + return is_triggering - @traceable.setter - def traceable(self, traceable: Iterable[Proposition]) -> None: - self._traceable = tuple(traceable) + def __str__(self) -> str: + return str(self.action) + + def __repr__(self) -> str: + return "EventAction(Rule.parse('{}'), name={})".format(self.action, self.name) def __hash__(self) -> int: - return hash((self.actions, self.commands, self.verb_tense, self.traceable)) + return hash((self.actions, self.commands, self.action)) def __eq__(self, other: Any) -> bool: - return (isinstance(other, EventAction) and - self.actions == other.actions and - self.commands == other.commands and - self.verb_tense == other.verb_tense and - self.traceable == other.traceable) + return (isinstance(other, EventAction) + and super().__eq__(other) + and self.action == other.action) @classmethod def deserialize(cls, data: Mapping) -> "EventAction": @@ -357,8 +332,8 @@ def deserialize(cls, data: Mapping) -> "EventAction": data: Serialized data with the needed information to build a `EventAction` object. """ - action = [Action.deserialize(d) for d in data["actions_EventAction"]] - return cls(action, data["verb_tense_EventAction"], data["commands_EventAction"]) + action = Rule.deserialize(data["action"]) + return cls(action=action) def serialize(self) -> Mapping: """ Serialize this event. @@ -366,42 +341,48 @@ def serialize(self) -> Mapping: Results: `EventAction`'s data serialized to be JSON compatible. """ - return {"actions_EventAction": [action.serialize() for action in self.actions], - "commands_EventAction": self.commands, - "verb_tense_EventAction": self.verb_tense, - } - - def copy(self) -> "EventAction": - """ Copy this event. """ - return self.deserialize(self.serialize()) + data = super().serialize() + data["action"] = self.action.serialize() + return data -class EventOr: - def __init__(self, events: Tuple =()): +class EventOr(AbstractEvent): + def __init__(self, events: Iterable[AbstractEvent] = ()): + super().__init__() self.events = events - self._any_triggered = False - self._any_untriggered = False + if len(self.events) == 1: + self.commands = self.events[0].commands + self.actions = self.events[0].actions @property - def events(self) -> Tuple[Union[EventAction, EventCondition]]: + def events(self) -> Tuple[AbstractEvent]: return self._events @events.setter - def events(self, events) -> None: + def events(self, events: Iterable[AbstractEvent]) -> None: self._events = tuple(events) - def are_triggering(self, state, action): - status = [] - for ev in self.events: - if isinstance(ev, EventCondition) or isinstance(ev, EventAction): - status.append(ev.is_triggering(state, [action])) - continue - status.append(ev.are_triggering(state, action)) + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered for a given state and/or action. """ + is_triggering = any(event.is_triggering(state, action, callback) for event in self.events) + if callback and is_triggering: + callback(self) + + return is_triggering + + def __iter__(self) -> Iterable[AbstractEvent]: + yield from self.events - return any(status) + def __len__(self) -> int: + return len(self.events) - def are_events_triggered(self, state, action): - return any((ev.is_triggering(state, action) for ev in self.events)) + def __repr__(self) -> str: + return "EventOr({!r})".format(self.events) + + def __str__(self) -> str: + return "EventOr({})".format(self.events) def __hash__(self) -> int: return hash(self.events) @@ -416,7 +397,9 @@ def serialize(self) -> Mapping: Results: EventOr's data serialized to be JSON compatible """ - return {"events_EventOr": [ev.serialize() for ev in self.events]} + data = super().serialize() + data["events"] = [e.serialize() for e in self.events] + return data @classmethod def deserialize(cls, data: Mapping) -> "EventOr": @@ -425,51 +408,46 @@ def deserialize(cls, data: Mapping) -> "EventOr": Args: data: Serialized data with the needed information to build a `EventOr` object. """ - events = [] - for d in data["events_EventOr"]: - if "condition_EventCondition" in d.keys(): - events.append(EventCondition.deserialize(d)) - elif "actions_EventAction" in d.keys(): - events.append(EventAction.deserialize(d)) - elif "actions_Event" in d.keys(): - events.append(Event.deserialize(d)) - elif "events_EventAnd" in d.keys(): - events.append(EventAnd.deserialize(d)) - elif "events_EventOr" in d.keys(): - events.append(EventOr.deserialize(d)) - - return cls(events) - - def copy(self) -> "EventOr": - """ Copy this EventOr. """ - return self.deserialize(self.serialize()) + return cls([AbstractEvent.deserialize(d) for d in data["events"]]) -class EventAnd: - def __init__(self, events: Tuple = ()): +class EventAnd(AbstractEvent): + def __init__(self, events: Iterable[AbstractEvent] = ()): + super().__init__() self.events = events - self._all_triggered = False - self._all_untriggered = False + if len(self.events) == 1: + self.commands = self.events[0].commands + self.actions = self.events[0].actions @property - def events(self) -> Tuple[Union[EventAction, EventCondition]]: + def events(self) -> Tuple[AbstractEvent]: return self._events @events.setter - def events(self, events) -> None: + def events(self, events: Iterable[AbstractEvent]) -> None: self._events = tuple(events) - def are_triggering(self, state, action): - status = [] - for ev in self.events: - if isinstance(ev, EventCondition) or isinstance(ev, EventAction): - status.append(ev.is_triggering(state, [action])) - continue - status.append(ev.are_triggering(state, action)) - return all(status) + def is_triggering(self, state: Optional[State] = None, + action: Optional[Action] = None, + callback: Optional[callable] = None) -> bool: + """ Check if this event would be triggered for a given state and/or action. """ + is_triggering = all(event.is_triggering(state, action, callback) for event in self.events) + if callback and is_triggering: + callback(self) + + return is_triggering - def are_events_triggered(self, state, action): - return all((ev.is_triggering(state, action) for ev in self.events)) + def __iter__(self) -> Iterable[AbstractEvent]: + yield from self.events + + def __len__(self) -> int: + return len(self.events) + + def __repr__(self) -> str: + return "EventAnd({!r})".format(self.events) + + def __str__(self) -> str: + return "EventAnd({})".format(self.events) def __hash__(self) -> int: return hash(self.events) @@ -484,7 +462,9 @@ def serialize(self) -> Mapping: Results: EventAnd's data serialized to be JSON compatible """ - return {"events_EventAnd": [ev.serialize() for ev in self.events]} + data = super().serialize() + data["events"] = [e.serialize() for e in self.events] + return data @classmethod def deserialize(cls, data: Mapping) -> "EventAnd": @@ -493,24 +473,7 @@ def deserialize(cls, data: Mapping) -> "EventAnd": Args: data: Serialized data with the needed information to build a `EventAnd` object. """ - events = [] - for d in data["events_EventAnd"]: - if "condition_EventCondition" in d.keys(): - events.append(EventCondition.deserialize(d)) - elif "actions_EventAction" in d.keys(): - events.append(EventAction.deserialize(d)) - elif "actions_Event" in d.keys(): - events.append(Event.deserialize(d)) - elif "events_EventAnd" in d.keys(): - events.append(EventAnd.deserialize(d)) - elif "events_EventOr" in d.keys(): - events.append(EventOr.deserialize(d)) - - return cls(events) - - def copy(self) -> "EventAnd": - """ Copy this EventAnd. """ - return self.deserialize(self.serialize()) + return cls([AbstractEvent.deserialize(d) for d in data["events"]]) class Quest: @@ -520,10 +483,10 @@ class Quest: a mutually exclusive set of failing events. Attributes: - win_events: Mutually exclusive set of winning events. That is, + win_event: Mutually exclusive set of winning events. That is, only one such event needs to be triggered in order to complete this quest. - fail_events: Mutually exclusive set of failing events. That is, + fail_event: Mutually exclusive set of failing events. That is, only one such event needs to be triggered in order to fail this quest. reward: Reward given for completing this quest. @@ -531,72 +494,64 @@ class Quest: commands: List of text commands leading to this quest completion. """ + _SERIAL_VERSION = 2 + def __init__(self, - win_events: Iterable[Union[EventAnd, EventOr]] = (), - fail_events: Iterable[Union[EventAnd, EventOr]] = (), + win_event: Optional[AbstractEvent] = None, + fail_event: Optional[AbstractEvent] = None, reward: Optional[int] = None, desc: Optional[str] = None, - commands: Iterable[str] = ()) -> None: + commands: Iterable[str] = (), + **kwargs) -> None: r""" Args: - win_events: Mutually exclusive set of winning events. That is, + win_event: Mutually exclusive set of winning events. That is, + only one such event needs to be triggered in order + to complete this quest. + fail_event: Mutually exclusive set of failing events. That is, only one such event needs to be triggered in order - to complete this quest. - fail_events: Mutually exclusive set of failing events. That is, - only one such event needs to be triggered in order - to fail this quest. + to fail this quest. reward: Reward given for completing this quest. By default, reward is set to 1 if there is at least one winning events otherwise it is set to 0. desc: A text description of the quest. commands: List of text commands leading to this quest completion. """ - self.win_events = win_events - self.fail_events = fail_events + # Backward compatibility: check for old argument names. + if "win_events" in kwargs: + win_event = kwargs["win_events"] + if "fail_events" in kwargs: + fail_event = kwargs["fail_events"] + + # Backward compatibility: convert list of Events to EventOr(events). + if win_event is not None and not isinstance(win_event, AbstractEvent): + win_event = EventOr(win_event) + + if fail_event is not None and not isinstance(fail_event, AbstractEvent): + fail_event = EventOr(fail_event) + + self.win_event = AbstractEvent.to_dnf(win_event) if win_event else None + self.fail_event = AbstractEvent.to_dnf(fail_event) if fail_event else None self.desc = desc self.commands = tuple(commands) - self.win_events_list = self.events_organizer(self.win_events) - self.fail_events_list = self.events_organizer(self.fail_events) - # Unless explicitly provided, reward is set to 1 if there is at least # one winning events otherwise it is set to 0. - self.reward = int(len(win_events) > 0) if reward is None else reward + self.reward = reward or int(self.win_event is not None) - if len(self.win_events) == 0 and len(self.fail_events) == 0: + if self.win_event is None and self.fail_event is None: raise UnderspecifiedQuestError() @property - def win_events(self) -> Iterable[Union[EventOr, EventAnd]]: - return self._win_events - - @win_events.setter - def win_events(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: - self._win_events = tuple(events) - - @property - def win_events_list(self) -> Iterable[Union[EventOr, EventAnd]]: - return self._win_events_list - - @win_events_list.setter - def win_events_list(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: - self._win_events_list = tuple(events) - - @property - def fail_events(self) -> Iterable[Union[EventOr, EventAnd]]: - return self._fail_events - - @fail_events.setter - def fail_events(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: - self._fail_events = tuple(events) + def events(self) -> Iterable[EventAnd]: + events = [] + if self.win_event: + events += list(self.win_event) - @property - def fail_events_list(self) -> Iterable[Union[EventOr, EventAnd]]: - return self._fail_events_list + if self.fail_event: + events += list(self.fail_event) - @fail_events_list.setter - def fail_events_list(self, events: Iterable[Union[EventOr, EventAnd]]) -> None: - self._fail_events_list = tuple(events) + return events @property def commands(self) -> Iterable[str]: @@ -606,44 +561,21 @@ def commands(self) -> Iterable[str]: def commands(self, commands: Iterable[str]) -> None: self._commands = tuple(commands) - def event_organizer(self, combined_event=(), _events=[]): - if isinstance(combined_event, EventCondition) or isinstance(combined_event, EventAction): - _events.append(combined_event) - return - - act = [] - for event in combined_event.events: - out = self.event_organizer(event, act) - if out: - for a in out: - _events.append(a) - - return (len(act) > 0 and len(act) > len(_events)) * act or (len(_events) > 0 and len(_events) > len(act)) * _events - - def events_organizer(self, combined_events=()): - _events_ = [] - for comb_ev in combined_events: - for ev in self.event_organizer(comb_ev, _events=[]): - _events_.append(ev) - - return _events_ - - def is_winning(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: - """ Check if this quest is winning in that particular state. """ + def is_winning(self, state: Optional[State] = None, action: Optional[Action] = None) -> bool: + """ Check if this quest is winning for a given state and/or after a given action. """ + return self.win_event.is_triggering(state, action) - return any(event.are_triggering(state, actions) for event in self.win_events) - - def is_failing(self, state: Optional[State] = None, actions: Tuple[Action] = ()) -> bool: - """ Check if this quest is failing in that particular state. """ - return any(event.are_triggering(state, actions) for event in self.fail_events) + def is_failing(self, state: Optional[State] = None, action: Optional[Action] = None) -> bool: + """ Check if this quest is failing for a given state and/or after a given action. """ + return self.fail_event.is_triggering(state, action) def __hash__(self) -> int: - return hash((self.win_events, self.fail_events, self.reward, self.desc, self.commands)) + return hash((self.win_event, self.fail_event, self.reward, self.desc, self.commands)) def __eq__(self, other: Any) -> bool: return (isinstance(other, Quest) - and self.win_events == other.win_events - and self.fail_events == other.fail_events + and self.win_event == other.win_event + and self.fail_event == other.fail_event and self.reward == other.reward and self.desc == other.desc and self.commands == other.commands) @@ -656,24 +588,22 @@ def deserialize(cls, data: Mapping) -> "Quest": data: Serialized data with the needed information to build a `Quest` object. """ - win_events = [] - for d in data["win_events"]: - if "events_EventOr" in d.keys(): - win_events.append(EventOr.deserialize(d)) - elif "events_EventAnd" in d.keys(): - win_events.append(EventAnd.deserialize(d)) - - fail_events = [] - for d in data["fail_events"]: - if "events_EventOr" in d.keys(): - fail_events.append(EventOr.deserialize(d)) - elif "events_EventAnd" in d.keys(): - fail_events.append(EventAnd.deserialize(d)) - + version = data.get("version", 1) + if version == 1: + win_events = [AbstractEvent.deserialize(event) for event in data["win_events"]] + fail_events = [AbstractEvent.deserialize(event) for event in data["fail_events"]] + commands = data.get("commands", []) + reward = data["reward"] + desc = data["desc"] + quest = cls(win_events, fail_events, reward, desc, commands) + return quest + + win_event = AbstractEvent.deserialize(data["win_event"]) if data["win_event"] else None + fail_event = AbstractEvent.deserialize(data["fail_event"]) if data["fail_event"] else None commands = data.get("commands", []) reward = data["reward"] desc = data["desc"] - return cls(win_events, fail_events, reward, desc, commands) + return cls(win_event, fail_event, reward, desc, commands) def serialize(self) -> Mapping: """ Serialize this quest. @@ -682,11 +612,12 @@ def serialize(self) -> Mapping: Quest's data serialized to be JSON compatible """ return { + "version": self._SERIAL_VERSION, "desc": self.desc, "reward": self.reward, "commands": self.commands, - "win_events": [event.serialize() for event in self.win_events], - "fail_events": [event.serialize() for event in self.fail_events] + "win_event": self.win_event.serialize() if self.win_event else None, + "fail_event": self.fail_event.serialize() if self.fail_event else None } def copy(self) -> "Quest": @@ -760,9 +691,16 @@ class Game: A `Game` is defined by a world and it can have quest(s) or not. Additionally, a grammar can be provided to control the text generation. + + Notes: + ----- + Here's the list of the diffrent `Game` class versions. + - v1: Initial version. + - v2: Games that have been created using the new Event classes. + """ - _SERIAL_VERSION = 1 + _SERIAL_VERSION = 2 def __init__(self, world: World, grammar: Optional[Grammar] = None, quests: Iterable[Quest] = ()) -> None: @@ -806,33 +744,19 @@ def change_grammar(self, grammar: Grammar) -> None: """ Changes the grammar used and regenerate all text. """ self.grammar = grammar - _gen_commands = partial(gen_commands_from_actions, kb=self.kb) if self.grammar: - from textworld.generator.inform7 import Inform7Game from textworld.generator.text_generation import generate_text_from_grammar - inform7 = Inform7Game(self) - _gen_commands = inform7.gen_commands_from_actions generate_text_from_grammar(self, self.grammar) - from textworld.generator.text_generation import describe_quests - self.objective = describe_quests(self, self.grammar) - - for quest in self.quests: - # TODO: should have a generic way of generating text commands from actions - # instead of relying on inform7 convention. - for event in quest.win_events_list: - event.commands = _gen_commands(event.actions) - - if quest.win_events_list: - quest.commands = quest.win_events_list[0].commands # Check if we can derive a global winning policy from the quests. if self.grammar: + from textworld.generator.text_generation import describe_event policy = GameProgression(self).winning_policy if policy: mapping = {k: info.name for k, info in self._infos.items()} commands = [a.format_command(mapping) for a in policy] self.metadata["walkthrough"] = commands - # self.objective = describe_event(EventCondition(actions=policy), self, self.grammar) + self.objective = describe_event(AbstractEvent(policy), self, self.grammar) def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ @@ -854,8 +778,12 @@ def deserialize(cls, data: Mapping) -> "Game": `Game` object. """ - version = data.get("version", cls._SERIAL_VERSION) - if version != cls._SERIAL_VERSION: + version = data.get("version", 1) + if version == 1: + msg = "Loading TextWorld game format (v{})! Current version is {}.".format(version, cls._SERIAL_VERSION) + warnings.warn(msg, TextworldGameVersionWarning) + + elif version != cls._SERIAL_VERSION: msg = "Cannot deserialize a TextWorld version {} game, expected version {}" raise ValueError(msg.format(version, cls._SERIAL_VERSION)) @@ -970,6 +898,21 @@ def objective(self) -> str: def objective(self, value: str): self._objective = value + @property + def walkthrough(self) -> Optional[List[str]]: + walkthrough = self.metadata.get("walkthrough") + if walkthrough: + return walkthrough + + # Check if we can derive a walkthrough from the quests. + policy = GameProgression(self).winning_policy + if policy: + mapping = {k: info.name for k, info in self._infos.items()} + walkthrough = [a.format_command(mapping) for a in policy] + self.metadata["walkthrough"] = walkthrough + + return walkthrough + class ActionDependencyTreeElement(DependencyTreeElement): """ Representation of an `Action` in the dependency tree. @@ -990,11 +933,7 @@ def depends_on(self, other: "ActionDependencyTreeElement") -> bool: of the action1 is not empty, i.e. action1 needs the propositions added by action2. """ - if isinstance(self.action, frozenset): - act = d = [a for a in self.action][0] - else: - act = self.action - return len(other.action.added & act._pre_set) > 0 + return len(other.action.added & self.action._pre_set) > 0 @property def action(self) -> Action: @@ -1103,51 +1042,56 @@ class EventProgression: relevant actions to be performed. """ - def __init__(self, event: Union[EventAnd, EventOr], kb: KnowledgeBase) -> None: + def __init__(self, event: AbstractEvent, kb: KnowledgeBase) -> None: """ Args: quest: The quest to keep track of its completion. """ self._kb = kb or KnowledgeBase.default() - self.event = event + self.event = event # TODO: convert to dnf just to be safe. self._triggered = False self._untriggerable = False - self._policy = () - - # Build a tree representation of the quest. - self._tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) - - action_list, _ = self.tree_policy(event) - for action in action_list: - self._tree.push(action) - self._policy = [a for a in action_list[::-1]] - - def tree_policy(self, event): - - if isinstance(event, EventCondition) or isinstance(event, EventAction): - if isinstance(event, EventCondition) and len(event.actions) > 0: - return [event.condition] + [action for action in event.actions[::-1]], 1 - elif isinstance(event, EventAction) and len(event.actions) > 0: - return [action for action in event.actions[::-1]], 0 - else: - return [], 1 - - _actions, _ev_type = [], [] - for ev in event.events: - a, b = self.tree_policy(ev) - _actions.append(a) - _ev_type.append(b) - - if isinstance(event, EventAnd): - act_list = [a for act in [x for _, x in sorted(zip(_ev_type, _actions))] for a in act] - elif isinstance(event, EventOr): - _actions = [x for x in _actions if len(x) > 0] - if _actions: - act_list = min(_actions, key=lambda act: len(act)) + self._policy = None + # self._policy = () + + # Build a tree representations for each subevent. + self._trees = [] + for events in self.event: # Assuming self.event is in DNF. + # trees = [] + + # Dummy action that should trigger when all events are triggered. + conditions = set() + + for event in events: + if isinstance(event, EventCondition): + conditions |= set(event.condition.preconditions) + elif isinstance(event, EventAction): + mapping = {ph: Variable(ph.name, ph.type) for ph in event.action.placeholders} + conditions |= set(predicate.instantiate(mapping) for predicate in event.action.postconditions) + else: + raise NotImplementedError() + + variables = sorted(set([v for c in conditions for v in c.arguments])) + event = Proposition("event", arguments=variables) + trigger = Action("trigger", preconditions=conditions, postconditions=list(conditions) + [event]) + + tree = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) + tree.push(trigger) + + if events.actions: + for action in events.actions[::-1]: + tree.push(action) else: - act_list = [] + for event in events: + for action in event.actions[::-1]: + tree.push(action) + + # trees.append(tree) - return act_list, 0 + # trees = ActionDependencyTree(kb=self._kb, + # element_type=ActionDependencyTreeElement, + # trees=trees) + self._trees.append(tree) def copy(self) -> "EventProgression": """ Return a soft copy. """ @@ -1155,17 +1099,37 @@ def copy(self) -> "EventProgression": ep._triggered = self._triggered ep._untriggerable = self._untriggerable ep._policy = self._policy - ep._tree = self._tree.copy() + ep._trees = [tree.copy() for tree in self._trees] return ep @property - def triggering_policy(self) -> List[Action]: - """ Actions to be performed in order to trigger the event. """ + def triggering_policy(self) -> Optional[List[Action]]: if self.done: return () - # Discard all "trigger" actions. - return tuple(a for a in self._policy if a.name != "trigger") + if self._policy is None or True: # TODO + policies = [] + for trees in self._trees: + # Discard all "trigger" actions. + policies.append(tuple(a for a in trees.flatten() if a.name != "trigger")) + + self._policy = min(policies, key=lambda policy: len(policy)) + + return self._policy + + @property + def _tree(self): + best = None + best_policy = None + for trees in self._trees: + # Discard all "trigger" actions. + policy = tuple(a for a in trees.flatten() if a.name != "trigger") + + if best is None or len(best_policy) > len(policy): + best = trees + best_policy = policy + + return best @property def done(self) -> bool: @@ -1180,9 +1144,11 @@ def triggered(self) -> bool: @property def untriggerable(self) -> bool: """ Check whether the event is in an untriggerable state. """ - return self._untriggerable + return len(self._trees) == 0 - def update(self, action: Tuple[Action] = (), state: Optional[State] = None) -> None: + def update(self, action: Optional[Action] = None, + state: Optional[State] = None, + callback: Optional[callable] = None) -> None: """ Update event progression given available information. Args: @@ -1194,23 +1160,29 @@ def update(self, action: Tuple[Action] = (), state: Optional[State] = None) -> N if state is not None: # Check if event is triggered. - self._triggered = self.event.are_triggering(state, action) + self._triggered = self.event.is_triggering(state, action, callback) + + # Update each dependency trees. + to_delete = [] + for i, trees in enumerate(self._trees): + if self._compress_policy(i, state): + continue # A shorter winning policy has been found. - # Try compressing the winning policy given the new game state. - if self.compress_policy(state): - return # A shorter winning policy has been found. + if action and not trees.empty: + # Determine if we moved away from the goal or closer to it. + changed, reverse_action = trees.remove(action) + if changed and reverse_action is None: # Irreversible action. + to_delete.append(trees) - if action and not self._tree.empty: - # Determine if we moved away from the goal or closer to it. - changed, reverse_action = self._tree.remove(action) - if changed and reverse_action is None: # Irreversible action. - self._untriggerable = True # Can't track quest anymore. + if changed and reverse_action is not None: + # Rebuild policy. + # self._policy = tuple(self._tree.flatten()) + self._policy = None # Will be rebuilt on the next call of triggering_policy. - if changed and reverse_action is not None: - # Rebuild policy. - self._policy = tuple(self._tree.flatten()) + for e in to_delete: + self._trees.remove(e) - def compress_policy(self, state: State) -> bool: + def _compress_policy(self, idx, state: State) -> bool: """ Compress the policy given a game state. Args: @@ -1219,34 +1191,30 @@ def compress_policy(self, state: State) -> bool: Returns: Whether the policy was compressed or not. """ + # Make sure the compressed policy has the same roots. + root_actions = [root.element.action for root in self._trees[idx].roots] def _find_shorter_policy(policy): for j in range(0, len(policy)): for i in range(j + 1, len(policy))[::-1]: shorter_policy = policy[:j] + policy[i:] - if state.is_sequence_applicable(shorter_policy): - self._tree = ActionDependencyTree(kb=self._kb, - element_type=ActionDependencyTreeElement) + if state.is_sequence_applicable(shorter_policy) and all(a in shorter_policy for a in root_actions): + self._trees[idx] = ActionDependencyTree(kb=self._kb, element_type=ActionDependencyTreeElement) for action in shorter_policy[::-1]: - self._tree.push(action) + self._trees[idx].push(action, allow_multi_root=True) return shorter_policy + return None compressed = False - policy = _find_shorter_policy(tuple(a for a in self._tree.flatten())) + policy = _find_shorter_policy(tuple(a for a in self._trees[idx].flatten())) while policy is not None: compressed = True - self._policy = policy policy = _find_shorter_policy(policy) return compressed - def will_trigger(self, state: State, action: Tuple[Action]): - triggered = self.event.are_triggering(state, action) - - return triggered - class QuestProgression: """ QuestProgression keeps track of the completion of a quest. @@ -1262,41 +1230,32 @@ def __init__(self, quest: Quest, kb: KnowledgeBase) -> None: """ self.quest = quest self.kb = kb - self.win_events = [EventProgression(event, kb) for event in quest.win_events] - self.fail_events = [EventProgression(event, kb) for event in quest.fail_events] + self.win_event = EventProgression(quest.win_event, kb) if quest.win_event is not None else None + self.fail_event = EventProgression(quest.fail_event, kb) if quest.fail_event is not None else None def copy(self) -> "QuestProgression": """ Return a soft copy. """ qp = QuestProgression(self.quest, self.kb) - qp.win_events = [event_progression.copy() for event_progression in self.win_events] - qp.fail_events = [event_progression.copy() for event_progression in self.fail_events] + qp.win_event = self.win_event.copy() if self.win_event is not None else None + qp.fail_event = self.fail_event.copy() if self.fail_event is not None else None return qp @property def _tree(self) -> Optional[List[ActionDependencyTree]]: - events = [event for event in self.win_events if len(event.triggering_policy) > 0] - if len(events) == 0: - return None - - event = min(events, key=lambda event: len(event.triggering_policy)) - return event._tree + return self.win_event._tree @property def winning_policy(self) -> Optional[List[Action]]: """ Actions to be performed in order to complete the quest. """ - if self.done: - return None - - winning_policies = [event.triggering_policy for event in self.win_events if len(event.triggering_policy) > 0] - if len(winning_policies) == 0: + if self.done or self.win_event is None: return None - return min(winning_policies, key=lambda policy: len(policy)) + return self.win_event.triggering_policy @property def completable(self) -> bool: """ Check if the quest has winning events. """ - return len(self.win_events) > 0 + return self.win_event is not None @property def done(self) -> bool: @@ -1306,20 +1265,21 @@ def done(self) -> bool: @property def completed(self) -> bool: """ Check whether the quest is completed. """ - return all(event.triggered for event in self.win_events) - # return any(event.triggered for event in self.win_events) + return self.win_event is not None and self.win_event.triggered @property def failed(self) -> bool: """ Check whether the quest has failed. """ - return any(event.triggered for event in self.fail_events) + return self.fail_event is not None and self.fail_event.triggered @property def unfinishable(self) -> bool: """ Check whether the quest is in an unfinishable state. """ - return any(event.untriggerable for event in self.win_events) + return self.win_event.untriggerable if self.win_event else False - def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + def update(self, action: Optional[Action] = None, + state: Optional[State] = None, + callback: Optional[callable] = None) -> None: """ Update quest progression given available information. Args: @@ -1329,8 +1289,15 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if self.done: return # Nothing to do, the quest is already done. - for event in (self.win_events + self.fail_events): - event.update(action, state) + if self.win_event: + self.win_event.update(action, state, callback) + + # Only update fail_event if the quest is not completed. + if self.completed: + return + + if self.fail_event: + self.fail_event.update(action, state, callback) class GameProgression: @@ -1348,13 +1315,15 @@ def __init__(self, game: Game, track_quests: bool = True) -> None: """ self.game = game self.state = game.world.state.copy() + self.callback = None self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) + self.quest_progressions = [] if track_quests: self.quest_progressions = [QuestProgression(quest, game.kb) for quest in game.quests] for quest_progression in self.quest_progressions: - quest_progression.update(action=(), state=self.state) + quest_progression.update(action=None, state=self.state) def copy(self) -> "GameProgression": """ Return a soft copy. """ @@ -1366,11 +1335,6 @@ def copy(self) -> "GameProgression": return gp - def valid_actions_gen(self): - potential_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - return [act for act in potential_actions if act.is_valid()] - @property def done(self) -> bool: """ Whether all quests are completed or at least one has failed or is unfinishable. """ @@ -1378,7 +1342,7 @@ def done(self) -> bool: @property def completed(self) -> bool: - """ Whether all quests are completed. """ + """ Whether all completable quests are completed. """ if not self.tracking_quests: return False # There is nothing to be "completed". @@ -1430,75 +1394,27 @@ def winning_policy(self) -> Optional[List[Action]]: master_quest_tree = ActionDependencyTree(kb=self.game.kb, element_type=ActionDependencyTreeElement, trees=trees) - actions = tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") - for action in actions: - if not action.command_template: - m = {c: d for c in self.game.kb.rules[action.name].placeholders for d in action.variables if c.type == d.type} - substitutions = {ph.name: "{{{}}}".format(var.name) for ph, var in m.items()} - action.command_template = self.game.kb.rules[action.name].command_template.format(**substitutions) # Discard all "trigger" actions. return tuple(a for a in master_quest_tree.flatten() if a.name != "trigger") - def any_traceable_exist(self, events): - if isinstance(events, EventCondition) or isinstance(events, EventAction): - return len(events.traceable) > 0 and not (events.traceable in self.state.facts) - - trc_exist = [] - for event in events.events: - trc_exist.append(self.any_traceable_exist(event)) - - return any(trc_exist) - - def add_traceables(self, action): - trace = [] - for quest_progression in self.quest_progressions: - if quest_progression.quest.reward >= 0: - for win_event in quest_progression.win_events: - if self.any_traceable_exist(win_event.event): - if win_event.will_trigger(self.state, tuple([action])): - trace.append(tr for eve in win_event.event.events for tr in eve.traceable) - - return [p for ar in trace for p in ar] - - def traceable_manager(self): - if not self.state.has_traceable(): - return - - for prop in self.state.get_facts(): - if not prop.name.startswith('is__'): - PropositionControl.remove(prop, self.state) - - def update(self, action: Action) -> None: + def update(self, action: Action, callback: Optional[callable] = None) -> None: """ Update the state of the game given the provided action. Args: action: Action affecting the state of the game. """ - # Update world facts - self.state.apply(action) - trace = self.add_traceables(action) - if trace: - for prop in trace: - if prop.name.startswith('has_been') and prop not in self.state.facts: - self.state.add_facts([prop]) - - # Update all quest progressions given the last action and new state. - for quest_progression in self.quest_progressions: - quest_progression.update(action, self.state) - # Update world facts. - if trace: - for prop in trace: - if not prop.name.startswith('has_been') and prop not in self.state.facts: - self.state.add_facts([prop]) - - self.traceable_manager() + self.state.apply(action) # Get valid actions. self._valid_actions = list(self.state.all_applicable_actions(self.game.kb.rules.values(), self.game.kb.types.constants_mapping)) + # Update all quest progressions given the last action and new state. + for quest_progression in self.quest_progressions: + quest_progression.update(action, self.state, callback or self.callback) + class GameOptions: """ @@ -1623,7 +1539,7 @@ def _key_missing(seeds): @property def rngs(self) -> Dict[str, RandomState]: rngs = {} - for key, seed in self._seeds.items(): + for key, seed in self.seeds.items(): rngs[key] = RandomState(seed) return rngs diff --git a/textworld/generator/inform7/tests/test_world2inform7.py b/textworld/generator/inform7/tests/test_world2inform7.py index 795d5b19..09861727 100644 --- a/textworld/generator/inform7/tests/test_world2inform7.py +++ b/textworld/generator/inform7/tests/test_world2inform7.py @@ -3,9 +3,14 @@ import itertools +import unittest +import shutil +import tempfile +from os.path import join as pjoin import textworld from textworld import g_rng +from textworld import testing from textworld.utils import make_temp_directory from textworld.core import EnvInfos @@ -104,9 +109,10 @@ def _rule_to_skip(rule): assert not done assert not game_state.won - game_state, _, done = env.step(event.commands[0]) - assert done - assert game_state.won + for cmd in game.walkthrough: + game_state, _, done = env.step(cmd) + assert done + assert game_state.won def test_quest_with_multiple_winning_and_losing_conditions(): @@ -479,3 +485,104 @@ def test_take_all_and_variants(): assert "blue ball:" in game_state.feedback assert "red ball" in game_state.inventory assert "blue ball" in game_state.inventory + + +class TestInform7Game(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.DATA = testing.build_complex_test_game() + cls.game = cls.DATA["game"] + cls.quest1 = cls.DATA["quest1"] + cls.quest2 = cls.DATA["quest2"] + cls.eating_carrot = cls.DATA["eating_carrot"] + cls.onion_eaten = cls.DATA["onion_eaten"] + cls.closing_chest_without_carrot = cls.DATA["closing_chest_without_carrot"] + + cls.tmpdir = pjoin(tempfile.mkdtemp(prefix="test_inform7_game"), "") + options = textworld.GameOptions() + options.path = cls.tmpdir + options.seeds = 20210512 + options.file_ext = ".z8" + cls.game_file = compile_game(cls.game, options) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + + def _apply_commands(self, env, commands, state=None): + state = state or env.reset() + for cmd in commands: + assert not state.done + state, _, done = env.step(cmd) + + return state + + def test_game_completion(self): + env = textworld.start(self.game_file) + + # Do quest1. + state = self._apply_commands(env, self.quest1.win_event.events[0].commands) + assert state.score == self.quest1.reward + assert not state.done + + # Then quest2. + state = self._apply_commands(env, self.quest2.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Alternative winning strategy for quest1. + state = self._apply_commands(env, self.quest1.win_event.events[1].commands) + assert state.score == self.quest1.reward + assert not state.done + + state = self._apply_commands(env, self.quest2.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Start with quest2, then quest1. + state = self._apply_commands(env, self.quest2.win_event.events[0].commands) + assert state.score == self.quest2.reward + assert not state.done + + state = self._apply_commands(env, self.quest1.win_event.events[0].commands, state) + assert state.score == state.max_score + assert state.done + assert state.won and not state.lost + + # Closing the chest containing the onion while holding the carrot should not fail. + state = self._apply_commands(env, self.eating_carrot.commands[:1]) # Take carrot. + state = self._apply_commands(env, self.quest1.win_event.events[1].commands, state) + assert state.score == self.quest1.reward + assert not state.done + + def test_game_failure(self): + env = textworld.start(self.game_file) + + # onion_eaten -> eating_carrot != eating_carrot -> onion_eaten + # Eating the carrot *after* eating the onion causes the game to be lost. + state = self._apply_commands(env, self.onion_eaten.commands) + state = self._apply_commands(env, self.eating_carrot.commands, state) + assert state.done + assert not state.won and state.lost + + # Eating the carrot *before* eating the onion does not lose the game, + # but the game becomes unfinishable. + state = self._apply_commands(env, self.eating_carrot.commands) + state = self._apply_commands(env, self.onion_eaten.commands, state) + assert not (state.done or state.won or state.lost) # Couldn't detect game is unfinishable. + + env_with_state_tracking = textworld.start(self.game_file, EnvInfos(policy_commands=True)) + state = self._apply_commands(env_with_state_tracking, self.eating_carrot.commands) + state = self._apply_commands(env_with_state_tracking, self.onion_eaten.commands, state) + assert not (state.done or state.won or state.lost) # Still won't tell the game is unfinishable. + assert state.policy_commands == [] # But there isn't no sequence of commands that can complete the game. + + # Closing the chest *while* the carrot is in the inventory. + state = self._apply_commands(env, ["open chest", "close chest"]) + assert not state.lost + state = self._apply_commands(env, self.closing_chest_without_carrot.commands, state) + assert state.done + assert not state.won and state.lost diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index 08c3f0c6..1f996a3d 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -15,7 +15,8 @@ from textworld.utils import make_temp_directory, str2bool, chunk -from textworld.generator.game import EventCondition, EventAction, Event, EventAnd, EventOr, Game +from textworld.generator.game import Game +from textworld.generator.game import EventCondition, EventAction from textworld.generator.world import WorldRoom, WorldEntity from textworld.logic import Signature, Proposition, Action, Variable @@ -99,12 +100,7 @@ def gen_source_for_attribute(self, attr: Proposition) -> Optional[str]: def gen_source_for_attributes(self, attributes: Iterable[Proposition]) -> str: source = "" for attr in attributes: - if attr.name.count('__') == 0: - attr_ = Proposition(name='is__' + attr.name, arguments=attr.arguments, verb='is', definition=attr.name) - else: - attr_ = attr - - source_attr = self.gen_source_for_attribute(attr_) + source_attr = self.gen_source_for_attribute(attr) if source_attr: source += source_attr + ".\n" @@ -119,32 +115,21 @@ def gen_source_for_conditions(self, conds: Iterable[Proposition]) -> str: if i7_cond: i7_conds.append(i7_cond) - # HACK: In Inform7 we have to mention a container/door is unlocked AND closed. - for cond in conds: + # HACK: In Inform7 we have to mention a container/door is unlocked AND closed. if cond.name == "closed": i7_conds.append("the {} is unlocked".format(cond.arguments[0].name)) - return " and ".join(i7_conds) + return i7_conds - def gen_source_for_rule(self, rule: Action) -> Optional[str]: - pt = self.kb.inform7_events[rule.name] + def gen_source_for_event_action(self, event: EventAction) -> Optional[str]: + pt = self.kb.inform7_events[event.action.name] if pt is None: - msg = "Undefined Inform7's command: {}".format(rule.name) + msg = "Undefined Inform7's action: {}".format(event.action.name) warnings.warn(msg, TextworldInform7Warning) - return None - - return pt.format(**self._get_entities_mapping(rule)) - - def gen_source_for_actions(self, acts: Iterable[Action]) -> str: - """Generate Inform 7 source for winning/losing actions.""" - - i7_acts = [] - for act in acts: - i7_act = self.gen_source_for_rule(act) - if i7_act: - i7_acts.append(i7_act) + return [] - return " and ".join(i7_acts) + mapping = {ph.type: ph.name for ph in event.action.placeholders} + return [pt.format(**mapping)] def gen_source_for_objects(self, objects: Iterable[WorldEntity]) -> str: source = "" @@ -218,15 +203,8 @@ def gen_source_for_rooms(self) -> str: def _get_name_mapping(self, action): mapping = self.kb.rules[action.name].match(action) - for ph, var in mapping.items(): - a = ph.name - b = self.entity_infos[var.name].name return {ph.name: self.entity_infos[var.name].name for ph, var in mapping.items()} - def _get_entities_mapping(self, action): - mapping = self.kb.rules[action.name].match(action) - return {ph.name: self.entity_infos[var.name].id for ph, var in mapping.items()} - def gen_commands_from_actions(self, actions: Iterable[Action]) -> List[str]: commands = [] for action in actions: @@ -271,8 +249,6 @@ def detect_action(self, i7_event: str, actions: Iterable[Action]) -> Optional[Ac """ # Prioritze actions with many precondition terms. actions = sorted(actions, key=lambda a: len(a.preconditions), reverse=True) - from pprint import pprint - pprint(actions) for action in actions: event = self.kb.inform7_events[action.name] if event.format(**self._get_name_mapping(action)).lower() == i7_event.lower(): @@ -346,9 +322,6 @@ def gen_source(self, seed: int = 1234) -> str: objective = self.game.objective.replace("\n", "[line break]") maximum_score = 0 - wining = 0 - quests_text, viewed_actions = [], {} - action_id = [] for quest_id, quest in enumerate(self.game.quests): maximum_score += quest.reward @@ -356,14 +329,12 @@ def gen_source(self, seed: int = 1234) -> str: The quest{quest_id} completed is a truth state that varies. The quest{quest_id} completed is usually false. """) - quest_ending = quest_completed.format(quest_id=quest_id) + source += quest_completed.format(quest_id=quest_id) - for event_id, event in enumerate(quest.win_events_list): - commands = self.gen_commands_from_actions(event.actions) - event.commands = commands - - walkthrough = '\nTest quest{}_{} with "{}"\n\n'.format(quest_id, event_id, " / ".join(commands)) - quest_ending += walkthrough + if quest.win_event: + commands = quest.win_event.commands or self.gen_commands_from_actions(quest.win_event.actions) + walkthrough = '\nTest quest{} with "{}"\n\n'.format(quest_id, " / ".join(commands)) + source += walkthrough # Add winning and losing conditions for quest. quest_ending_conditions = textwrap.dedent("""\ @@ -371,76 +342,58 @@ def gen_source(self, seed: int = 1234) -> str: do nothing;""".format(quest_id=quest_id)) fail_template = textwrap.dedent(""" - otherwise if {conditions}: - end the story; [Lost];""") + else if {conditions}: + end the story; [Lost]""") win_template = textwrap.dedent(""" - otherwise if {conditions}: + else if {conditions}: increase the score by {reward}; [Quest completed] - Now the quest{quest_id} completed is true; - {removed_conditions}""") - - otherwise_template = textwrap.dedent("""\ - otherwise: - {removed_conditions}""") - - conditions, removals = '', '' - cond_id = [] - for fail_event in quest.fail_events: - condition, removed_conditions, final_condition, _, _ = self.get_events(fail_event, - textwrap.dedent(""""""), - textwrap.dedent(""""""), - action_id=action_id, - cond_id=cond_id, - quest_id=quest_id, - rwd_conds=viewed_actions) - removals += (len(removals) > 0) * ' ' + '' + removed_conditions - quest_ending_conditions += fail_template.format(conditions=final_condition) - conditions += condition - - wining += 1 - - for win_event in quest.win_events: - condition, removed_conditions, final_condition, _, _ = self.get_events(win_event, - textwrap.dedent(""""""), - textwrap.dedent(""""""), - action_id=action_id, - cond_id=cond_id, - quest_id=quest_id, - rwd_conds=viewed_actions) - removals += (len(removals) > 0) * ' ' + '' + removed_conditions - quest_ending_conditions += win_template.format(reward=quest.reward, quest_id=quest_id, - conditions=final_condition, - removed_conditions=textwrap.indent(removals, "")) - conditions += condition - - wining += 1 - - if removals: - quest_ending_conditions += otherwise_template.format(removed_conditions=textwrap.indent(removals, "")) - - quest_condition_template = """\ + Now the quest{quest_id} completed is true;""") + + if quest.win_event: + # Assuming quest.win_event is in a DNF. + for events in quest.win_event: # Loop over EventOr + i7_conds = [] + for event in events: # Loop over EventAnd + if isinstance(event, EventCondition): + i7_conds += self.gen_source_for_conditions(event.condition.preconditions) + elif isinstance(event, EventAction): + i7_conds += self.gen_source_for_event_action(event) + else: + raise NotImplementedError("Unknown event type: {!r}".format(event)) + + quest_ending_conditions += win_template.format(conditions=" and ".join(i7_conds), + reward=quest.reward, + quest_id=quest_id) + + if quest.fail_event: + # Assuming quest.fail_event is in a DNF. + for events in quest.fail_event: # Loop over EventOr + i7_conds = [] + for event in events: # Loop over EventAnd + if isinstance(event, EventCondition): + i7_conds += self.gen_source_for_conditions(event.condition.preconditions) + elif isinstance(event, EventAction): + i7_conds += self.gen_source_for_event_action(event) + else: + raise NotImplementedError("Unknown event type: {!r}".format(event)) + + quest_ending_conditions += fail_template.format(conditions=" and ".join(i7_conds)) + + quest_ending = """\ Every turn:\n{conditions} """.format(conditions=textwrap.indent(quest_ending_conditions, " ")) - - quest_ending += textwrap.dedent(quest_condition_template) - - source += textwrap.dedent(conditions) - source += textwrap.dedent('\n') - quests_text += [quest_ending] - - source += textwrap.dedent('\n'.join(txt for txt in quests_text if txt)) + source += textwrap.dedent(quest_ending) # Enable scoring is at least one quest has nonzero reward. - if maximum_score >= 0: + if maximum_score != 0: source += "Use scoring. The maximum score is {}.\n".format(maximum_score) # Build test condition for winning the game. game_winning_test = "1 is 0 [always false]" - if wining > 0: - if maximum_score != 0: - game_winning_test = "score is at least maximum score" + if len(self.game.quests) > 0: + game_winning_test = "score is maximum score" # Remove square bracket when printing score increases. Square brackets are conflicting with # Inform7's events parser in tw_inform7.py. @@ -458,7 +411,6 @@ def gen_source(self, seed: int = 1234) -> str: if {game_winning_test}: end the story finally; [Win] - The simpler notify score changes rule substitutes for the notify score changes rule. """.format(game_winning_test=game_winning_test)) @@ -1044,88 +996,6 @@ def gen_source(self, seed: int = 1234) -> str: return source - def get_events(self, combined_events, txt, rmv, quest_id, rwd_conds, action_id=[], - cond_id=[], check_vars=[]): - - action_processing_template = textwrap.dedent(""" - The action{action_id} check is a truth state that varies. - The action{action_id} check is usually false. - After {actions}: - Now the action{action_id} check is true. - """) - - remove_action_processing_template = textwrap.dedent("""Now the action{action_id} check is false; - """) - - combined_ac_processing_template = textwrap.dedent(""" - The condition{cond_id} of quest{quest_id} check is a truth state that varies. - The condition{cond_id} of quest{quest_id} check is usually false. - Every turn: - if {conditions}: - Now the condition{cond_id} of quest{quest_id} check is true. - """) - - remove_condition_processing_template = textwrap.dedent("""Now the condition{cond_id} of quest{quest_id} check is false; - """) - - if isinstance(combined_events, EventCondition) or isinstance(combined_events, EventAction): - if isinstance(combined_events, EventCondition): - check_vars += [self.gen_source_for_conditions(combined_events.condition.preconditions)] - return [None] * 5 - - elif isinstance(combined_events, EventAction): - i7_ = self.gen_source_for_actions(combined_events.actions) - if not rwd_conds or i7_ not in rwd_conds.values(): - txt += [action_processing_template.format(action_id=len(action_id), actions=i7_)] - rmv += [remove_action_processing_template.format(action_id=len(action_id))] - temp = [self.gen_source_for_conditions([prop]) for prop in combined_events.actions[0].preconditions - if prop.verb != 'is'] - if temp: - temp = ' and ' + ' and '.join(t for t in temp) - else: - temp = '' - check_vars += ['action{action_id} check is true'.format(action_id=len(action_id)) + temp] - rwd_conds['action{action_id}'.format(action_id=len(action_id))] = i7_ - action_id += [1] - else: - word = list(rwd_conds.keys())[list(rwd_conds.values()).index(i7_)] - rmv += [remove_action_processing_template.format(action_id=word[6:])] - temp = [self.gen_source_for_conditions([prop]) for prop in combined_events.actions[0].preconditions - if prop.verb != 'is'] - check_vars += ['action{action_id} check is true'.format(action_id=word[6:]) + ' and ' + - ' and '.join(t for t in temp)] - - return [None] * 5 - - act_type, _txt, _rmv, _check_vars, _cond_id = [], [], [], [], [] - for event in combined_events.events: - st, rm, a3, a4, cond_type = self.get_events(event, _txt, _rmv, quest_id, rwd_conds, action_id, cond_id, - check_vars=_check_vars) - act_type.append(isinstance(event, EventAction)) - - if st: - _txt += [st] - _rmv += [rm] - _check_vars.append('condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id)-1, - quest_id=quest_id)) - if cond_type: - _cond_id += cond_type - - if any(_cond_id): - _rmv += [remove_condition_processing_template.format(quest_id=quest_id, cond_id=len(cond_id) - 1)] - - event_rule = isinstance(combined_events, EventAnd) * ' and ' + isinstance(combined_events, EventOr) * ' or ' - condition_ = event_rule.join(cv for cv in _check_vars) - tp_txt = ''.join(tx for tx in _txt) - tp_txt += combined_ac_processing_template.format(quest_id=quest_id, cond_id=len(cond_id), conditions=condition_) - tp_rmv = ' '.join(ac for ac in _rmv if ac) - fin_cond = 'condition{cond_id} of quest{quest_id} check is true'.format(cond_id=len(cond_id), quest_id=quest_id) - cond_id += [1] - if any(act_type): - cond_type = [True] - - return tp_txt, tp_rmv, fin_cond, [action_id, cond_id, rwd_conds], cond_type - def generate_inform7_source(game: Game, seed: int = 1234, use_i7_description: bool = False) -> str: inform7 = Inform7Game(game) diff --git a/textworld/generator/logger.py b/textworld/generator/logger.py index 88a118d7..358733fa 100644 --- a/textworld/generator/logger.py +++ b/textworld/generator/logger.py @@ -76,8 +76,8 @@ def collect(self, game): # Collect distribution of commands leading to winning events. for quest in game.quests: self.quests.add(quest.desc) - for event in quest.win_events: - actions = event.actions + for events in quest.win_event: + actions = events.actions update_bincount(self.dist_quest_length_count, len(actions)) for action in actions: diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index bfecfb6e..569fd74e 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -19,9 +19,9 @@ from textworld.generator.graph_networks import direction from textworld.generator.data import KnowledgeBase from textworld.generator.vtypes import get_new -from textworld.logic import State, Variable, Proposition, Action +from textworld.logic import State, Variable, Proposition, Action, Placeholder from textworld.generator.game import GameOptions -from textworld.generator.game import Game, World, Quest, EventAnd, EventOr, EventCondition, EventAction, EntityInfo +from textworld.generator.game import Game, World, Quest, EventAnd, EventCondition, EventAction, EntityInfo from textworld.generator.graph_networks import DIRECTIONS from textworld.render import visualize from textworld.envs.wrappers import Recorder @@ -29,7 +29,7 @@ def get_failing_constraints(state, kb: Optional[KnowledgeBase] = None): kb = kb or KnowledgeBase.default() - fail = Proposition("is__fail", []) + fail = Proposition("fail", []) failed_constraints = [] constraints = state.all_applicable_actions(kb.constraints.values()) @@ -45,45 +45,6 @@ def get_failing_constraints(state, kb: Optional[KnowledgeBase] = None): return failed_constraints -def new_operation(operation={}): - def func(operator='or', events=[]): - if operator == 'or' and events: - return EventOr(events=events) - if operator == 'and' and events: - return EventAnd(events=events) - else: - raise - - if not isinstance(operation, dict): - if len(operation) == 0: - return () - else: - operation = {'or': tuple(ev for ev in operation)} - - y1 = [] - for k, v in operation.items(): - if isinstance(v, dict): - y1.append(new_operation(operation=v)[0]) - y1 = [func(k, y1)] - else: - if isinstance(v, EventCondition) or isinstance(v, EventAction): - y1.append(func(k, [v])) - else: - if any((isinstance(it, dict) for it in v)): - y2 = [] - for it in v: - if isinstance(it, dict): - y2.append(new_operation(operation=it)[0]) - else: - y2.append(func(k, [it])) - - y1 = [func(k, y2)] - else: - y1.append(func(k, v)) - - return tuple(y1) - - class MissingPlayerError(ValueError): pass @@ -190,7 +151,7 @@ def add_fact(self, name: str, *entities: List["WorldEntity"]) -> None: *entities: A list of entities as arguments to the new fact. """ args = [entity.var for entity in entities] - self._facts.append(Proposition(name='is__' + name, arguments=args)) + self._facts.append(Proposition(name, args)) def remove_fact(self, name: str, *entities: List["WorldEntity"]) -> None: args = [entity.var for entity in entities] @@ -710,8 +671,8 @@ def set_quest_from_commands(self, commands: List[str]) -> Quest: unrecognized_commands = [c for c, a in zip(commands, recorder.actions) if a is None] raise QuestError("Some of the actions were unrecognized: {}".format(unrecognized_commands)) - event = self.new_event(action=actions, condition=winning_facts, command=commands, event_style=event_style) - self.quests = [self.new_quest(win_event=[event])] + event = EventCondition(actions=actions, commands=commands) + self.quests = [Quest(win_event=event, commands=commands)] # Calling build will generate the description for the quest. self.build() @@ -734,51 +695,21 @@ def new_action(self, name: str, *entities: List["WorldEntity"]) -> Union[None, A name: The name of the rule which can be used for the new rule fact as well. *entities: A list of entities as arguments to the new rule fact. """ + if name not in self._kb.rules: + raise ValueError("Can't find action: '{}'".format(name)) - def new_conditions(conditions, args): - new_ph = [] - for pred in conditions: - new_var = [var for ph in pred.parameters for var in args if ph.type == var.type] - new_ph.append(Proposition(name=pred.name, arguments=new_var)) - return new_ph - - args = [entity.var for entity in entities] - - for rule in self._kb.rules.values(): - if rule.name == name.name: - precond = new_conditions(rule.preconditions, args) - postcond = new_conditions(rule.postconditions, args) - - action = Action(rule.name, precond, postcond) - - if action.has_traceable(): - action.activate_traceable() - - return action - - return None - - def new_event(self, action: Iterable[Action] = (), condition: Iterable[Proposition] = (), - command: Iterable[str] = (), condition_verb_tense: dict = (), action_verb_tense: dict = (), - event_style: str = 'condition'): - if event_style == 'condition': - event = EventCondition(conditions=condition, verb_tense=condition_verb_tense, actions=action, - commands=command) - return event - elif event_style == 'action': - event = EventAction(actions=action, verb_tense=action_verb_tense, commands=command) - return event - else: - raise UnderspecifiedEventError + rule = self._kb.rules[name] + mapping = {Placeholder(entity.type): Placeholder(entity.id, entity.type) for entity in entities} + return rule.substitute(mapping) - def new_quest(self, win_event=(), fail_event=(), reward=None, desc=None, commands=()) -> Quest: - return Quest(win_events=new_operation(operation=win_event), - fail_events=new_operation(operation=fail_event), + def new_quest(self, win_event=None, fail_event=None, reward=None, desc=None, commands=()) -> Quest: + return Quest(win_event=win_event, + fail_event=fail_event, reward=reward, desc=desc, commands=commands) - def new_event_using_commands(self, commands: List[str], event_style: str) -> Union[EventCondition, EventAction]: + def new_event_using_commands(self, commands: List[str]) -> Union[EventCondition, EventAction]: """ Creates a new event using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -800,10 +731,10 @@ def new_event_using_commands(self, commands: List[str], event_style: str) -> Uni # Skip "None" actions. actions, commands = zip(*[(a, c) for a, c in zip(recorder.actions, commands) if a is not None]) - event = self.new_event(action=actions, command=commands, event_style=event_style) + event = EventCondition(actions=actions, commands=commands) return event - def new_quest_using_commands(self, commands: List[str], event_style: str) -> Quest: + def new_quest_using_commands(self, commands: List[str]) -> Quest: """ Creates a new quest using predefined text commands. This launches a `textworld.play` session to execute provided commands. @@ -814,38 +745,75 @@ def new_quest_using_commands(self, commands: List[str], event_style: str) -> Que Returns: The resulting quest. """ - event = self.new_event_using_commands(commands, event_style=event_style) - return Quest(win_events=new_operation(operation=[event]), commands=event.commands) + event = self.new_event_using_commands(commands) + return Quest(win_event=event, commands=event.commands) + + def set_walkthrough(self, *walkthroughs: List[str]): + # Assuming quest.events return a list of EventAnd. + events = {event: event.copy() for quest in self.quests for event in quest.events} + + actions = [] + cmds_performed = [] + + def _callback(event): + if not isinstance(event, EventAnd): + return + + if event not in events: + assert False + + if event not in events or events[event].commands: + return + + events[event].commands = list(cmds_performed) - def set_walkthrough(self, commands: List[str]): with make_temp_directory() as tmpdir: game_file = self.compile(pjoin(tmpdir, "set_walkthrough.ulx")) env = textworld.start(game_file, infos=EnvInfos(last_action=True, intermediate_reward=True)) - state = env.reset() - events = {event: event.copy() for quest in self.quests for event in quest.win_events} - event_progressions = [ep for qp in state._game_progression.quest_progressions for ep in qp.win_events] + for walkthrough in walkthroughs: + state = env.reset() + state._game_progression.callback = _callback + + done = False + for i, cmd in enumerate(walkthrough): + if done: + msg = "Game has ended before finishing playing all commands." + raise ValueError(msg) + + cmds_performed.append(cmd) + state, score, done = env.step(cmd) + actions.append(state._last_action) + + for k, v in events.items(): + if v.commands and not k.actions: + k.commands = v.commands + k.actions = list(actions[:len(v.commands)]) + + actions.clear() + cmds_performed.clear() + + for quest in self.quests: + if quest.win_event: + quest.commands = quest.win_event.commands + + def get_action_from_commands(self, commands: List[str]): + with make_temp_directory() as tmpdir: + game_file = self.compile(pjoin(tmpdir, "get_actions.ulx")) + env = textworld.start(game_file, infos=EnvInfos(last_action=True)) + state = env.reset() - done = False actions = [] + done = False for i, cmd in enumerate(commands): if done: msg = "Game has ended before finishing playing all commands." raise ValueError(msg) - events_triggered = [ep.triggered for ep in event_progressions] - state, score, done = env.step(cmd) actions.append(state._last_action) - for was_triggered, ep in zip(events_triggered, event_progressions): - if not was_triggered and ep.triggered: - events[ep.event].actions = list(actions) - events[ep.event].commands = commands[:i + 1] - - for k, v in events.items(): - k.actions = v.actions - k.commands = v.commands + return actions def validate(self) -> bool: """ Check if the world is valid and can be compiled. diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index abc4d40f..f56a8301 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -6,30 +6,29 @@ import textwrap from typing import Iterable -import numpy.testing as npt - import textworld from textworld import g_rng from textworld import GameMaker +from textworld import testing from textworld.generator.data import KnowledgeBase from textworld.generator import World from textworld.generator import make_small_map -from textworld.generator.maker import new_operation from textworld.generator.chaining import ChainingOptions, sample_quest -from textworld.logic import Action - +from textworld.logic import Action, State, Proposition, Rule from textworld.generator.game import GameOptions from textworld.generator.game import Quest, Game, Event, EventAction, EventCondition, EventOr, EventAnd from textworld.generator.game import QuestProgression, GameProgression, EventProgression -from textworld.generator.game import UnderspecifiedEventError, UnderspecifiedQuestError from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement from textworld.generator.inform7 import Inform7Game from textworld.logic import GameLogic +DATA = testing.build_complex_test_game() + + def _find_action(command: str, actions: Iterable[Action], inform7: Inform7Game) -> None: """ Apply a text command to a game_progression object. """ commands = inform7.gen_commands_from_actions(actions) @@ -120,47 +119,32 @@ def test_variable_infos(verbose=False): assert var_infos.desc is not None -class TestEventCondition(unittest.TestCase): - - @classmethod - def setUpClass(cls): - M = GameMaker() - - # The goal - commands = ["take carrot", "insert carrot into chest"] +class TestEvent(unittest.TestCase): - R1 = M.new_room("room") - M.set_player(R1) + def test_init(self): + event = Event(conditions=[Proposition.parse("in(carrot: f, chest: c)")]) + assert type(event) is EventCondition - carrot = M.new(type='f', name='carrot') - R1.add(carrot) - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) - - cls.event = M.new_event_using_commands(commands, event_style='condition') - cls.actions = cls.event.actions - cls.traceable = cls.event.traceable - cls.conditions = {M.new_fact("in", carrot, chest)} +class TestEventCondition(unittest.TestCase): - def test_init(self): - event = EventCondition(actions=self.actions) - assert event.actions == self.actions - assert event.condition == self.event.condition - assert event.traceable == self.traceable - assert event.condition.preconditions == self.actions[-1].postconditions - assert set(event.condition.preconditions).issuperset(self.conditions) - - event = EventCondition(conditions=self.conditions) - assert len(event.actions) == 0 - assert event.traceable == self.traceable - assert set(event.condition.preconditions) == set(self.conditions) - - npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[]) - npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[], conditions=[]) - npt.assert_raises(UnderspecifiedEventError, EventCondition, conditions=[]) + @classmethod + def setUpClass(cls): + cls.condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event = EventCondition(conditions=cls.condition) + + def test_is_triggering(self): + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + Proposition.parse("in(lettuce: f, chest: c)"), + ]) + assert self.event.is_triggering(state=state) + + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + Proposition.parse("in(lettuce: f, chest: c)"), + ]) + assert not self.event.is_triggering(state=state) def test_serialization(self): data = self.event.serialize() @@ -177,32 +161,22 @@ class TestEventAction(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - # The goal - commands = ["take carrot"] - - R1 = M.new_room("room") - M.set_player(R1) - - carrot = M.new(type='f', name='carrot') - R1.add(carrot) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) - - cls.event = M.new_event_using_commands(commands, event_style='action') - cls.actions = cls.event.actions - cls.traceable = cls.event.traceable - - def test_init(self): - event = EventAction(actions=self.actions) - assert event.actions == self.actions - assert event.traceable == self.traceable - - npt.assert_raises(UnderspecifiedEventError, EventCondition, actions=[]) + cls.rule = Rule.parse("close :: $at(P, r) & $at(chest: c, r) & open(chest: c) -> closed(chest: c)") + cls.action = Action.parse("close :: $at(P, room: r) & $at(chest: c, room: r) & open(chest: c) -> closed(chest: c)") + cls.event = EventAction(action=cls.rule) + + def test_is_triggering(self): + # State should be ignored in a EventAction. + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("open(chest: c)"), + ]) + assert self.event.is_triggering(state=state, action=self.action) + + state = State(KnowledgeBase.default().logic, [ + Proposition.parse("closed(chest: c)"), + ]) + action = Action.parse("close :: open(fridge: c) -> closed(fridge: c)") + assert not self.event.is_triggering(state=state, action=action) def test_serialization(self): data = self.event.serialize() @@ -219,174 +193,162 @@ class TestEventOr(unittest.TestCase): @classmethod def setUpClass(cls): - - M = GameMaker() - - # The goal - commands = ["take lime juice", "insert lime juice into chest", "take carrot"] - - R1 = M.new_room("room") - M.set_player(R1) - - lime = M.new(type='f', name='lime juice') - R1.add(lime) - - carrot = M.new(type='f', name='carrot') - R1.add(carrot) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) - - cls.first_event = M.new_event_using_commands(commands[:-1], event_style='condition') - cls.first_event_actions = cls.first_event.actions - cls.first_event_traceable = cls.first_event.traceable - cls.first_event_conditions = {M.new_fact("in", lime, chest)} - - cls.second_event = M.new_event_using_commands([commands[-1]], event_style='action') - cls.second_event_actions = cls.second_event.actions - cls.second_event_traceable = cls.second_event.traceable - - cls.event = EventOr(events=(cls.first_event, cls.second_event)) - cls.events = cls.event.events - - def test_init(self): - first_event = EventCondition(actions=self.first_event_actions) - second_event = EventAction(actions=self.second_event_actions) - event = EventOr(events=(first_event, second_event)) - - assert event.events[0].actions == self.first_event_actions - assert event.events[0].condition == self.first_event.condition - assert event.events[0].traceable == self.first_event_traceable - assert event.events[0].condition.preconditions == self.first_event_actions[-1].postconditions - assert set(event.events[0].condition.preconditions).issuperset(self.first_event_conditions) - assert event.events[1].actions == self.second_event_actions - assert event.events[1].traceable == self.second_event_traceable + cls.event_A_condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_A = EventCondition(conditions=cls.event_A_condition) + + cls.event_B_action = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_B = EventAction(action=cls.event_B_action) + + cls.event_A_or_B = EventOr(events=(cls.event_A, cls.event_B)) + + def test_is_triggering(self): + open_chest = Action.parse("open :: closed(chest: c) -> open(chest: c)") + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + # A | B + assert self.event_A.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_chest, action=close_chest) + + # !A | !B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_A_or_B.is_triggering(state=carrot_in_inventory, action=open_chest) + + # !A | B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_inventory, action=close_chest) + + # A | !B + assert self.event_A.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_chest, action=open_chest) + assert self.event_A_or_B.is_triggering(state=carrot_in_chest, action=open_chest) def test_serialization(self): - data = self.event.serialize() + data = self.event_A_or_B.serialize() event = EventOr.deserialize(data) - assert event == self.event + assert event == self.event_A_or_B def test_copy(self): - event = self.event.copy() - assert event == self.event - assert id(event) != id(self.event) + event = self.event_A_or_B.copy() + assert event == self.event_A_or_B + assert id(event) != id(self.event_A_or_B) class TestEventAnd(unittest.TestCase): @classmethod def setUpClass(cls): - - M = GameMaker() - - # The goal - commands = ["take lime juice", "insert lime juice into chest", "take carrot"] - - R1 = M.new_room("room") - M.set_player(R1) - - lime = M.new(type='f', name='lime juice') - R1.add(lime) - - carrot = M.new(type='f', name='carrot') - R1.add(carrot) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) - - cls.first_event = M.new_event_using_commands(commands[:-1], event_style='condition') - cls.first_event_actions = cls.first_event.actions - cls.first_event_traceable = cls.first_event.traceable - cls.first_event_conditions = {M.new_fact("in", lime, chest)} - - cls.second_event = M.new_event_using_commands([commands[-1]], event_style='action') - cls.second_event_actions = cls.second_event.actions - cls.second_event_traceable = cls.second_event.traceable - - cls.event = EventAnd(events=(cls.first_event, cls.second_event)) - cls.events = cls.event.events - - def test_init(self): - first_event = EventCondition(actions=self.first_event_actions) - second_event = EventAction(actions=self.second_event_actions) - event = EventAnd(events=(first_event, second_event)) - - assert event.events[0].actions == self.first_event_actions - assert event.events[0].condition == self.first_event.condition - assert event.events[0].traceable == self.first_event_traceable - assert event.events[0].condition.preconditions == self.first_event_actions[-1].postconditions - assert set(event.events[0].condition.preconditions).issuperset(self.first_event_conditions) - assert event.events[1].actions == self.second_event_actions - assert event.events[1].traceable == self.second_event_traceable + cls.event_A_condition = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_A = EventCondition(conditions=cls.event_A_condition) + + cls.event_B_action = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_B = EventAction(action=cls.event_B_action) + + cls.event_A_and_B = EventAnd(events=(cls.event_A, cls.event_B)) + + def test_is_triggering(self): + open_chest = Action.parse("open :: closed(chest: c) -> open(chest: c)") + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + # A & B + assert self.event_A.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_chest, action=close_chest) + assert self.event_A_and_B.is_triggering(state=carrot_in_chest, action=close_chest) + + # !A & !B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_inventory, action=open_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_inventory, action=open_chest) + + # !A & B + assert not self.event_A.is_triggering(state=carrot_in_inventory, action=close_chest) + assert self.event_B.is_triggering(state=carrot_in_inventory, action=close_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_inventory, action=close_chest) + + # A & !B + assert self.event_A.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_B.is_triggering(state=carrot_in_chest, action=open_chest) + assert not self.event_A_and_B.is_triggering(state=carrot_in_chest, action=open_chest) def test_serialization(self): - data = self.event.serialize() + data = self.event_A_and_B.serialize() event = EventAnd.deserialize(data) - assert event == self.event + assert event == self.event_A_and_B def test_copy(self): - event = self.event.copy() - assert event == self.event - assert id(event) != id(self.event) + event = self.event_A_and_B.copy() + assert event == self.event_A_and_B + assert id(event) != id(self.event_A_and_B) class TestQuest(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() + cls.carrot_in_chest = {Proposition.parse("in(carrot: f, chest: c)")} + cls.event_carrot_in_chest = EventCondition(conditions=cls.carrot_in_chest) - # The goal - commands = ["open wooden door", "go east", "insert carrot into chest"] + cls.close_chest = Rule.parse("close :: open(chest: c) -> closed(chest: c)") + cls.event_close_chest = EventAction(action=cls.close_chest) - # Create a 'bedroom' room. - R1 = M.new_room("bedroom") - R2 = M.new_room("kitchen") - M.set_player(R1) - - path = M.connect(R1.east, R2.west) - door_a = M.new_door(path, name="wooden door") - M.add_fact("closed", door_a) + cls.event_closing_chest_with_carrot = EventAnd(events=(cls.event_carrot_in_chest, cls.event_close_chest)) - carrot = M.new(type='f', name='carrot') - M.inventory.add(carrot) + cls.carrot_in_inventory = {Proposition.parse("in(carrot: f, I: I)")} + cls.event_carrot_in_inventory = EventCondition(conditions=cls.carrot_in_inventory) - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R2.add(chest) + cls.event_closing_chest_without_carrot = EventAnd(events=(cls.event_carrot_in_inventory, cls.event_close_chest)) - cls.eventA = M.new_event_using_commands(commands, event_style='condition') - cls.eventB = M.new_event(condition={M.new_fact("at", carrot, R1), M.new_fact("closed", path.door)}, - event_style='condition') - cls.eventC = M.new_event(condition={M.new_fact("eaten", carrot)}, event_style='condition') - cls.eventD = M.new_event(condition={M.new_fact("closed", chest), M.new_fact("closed", path.door)}, - event_style='condition') - cls.quest = M.new_quest(win_event={'or': (cls.eventA, cls.eventB)}, - fail_event={'or': (cls.eventC, cls.eventD)}, reward=2) - M.quests = [cls.quest] - cls.game = M.build() - cls.inform7 = Inform7Game(cls.game) + cls.eat_carrot = Rule.parse("eat :: in(carrot: f, I: I) -> consumed(carrot: f)") + cls.event_eat_carrot = EventAction(action=cls.eat_carrot) - def test_init(self): - npt.assert_raises(UnderspecifiedQuestError, Quest) + cls.event_closing_chest_whithout_carrot_or_eating_carrot = \ + EventOr(events=(cls.event_closing_chest_without_carrot, cls.event_eat_carrot)) - quest = Quest(win_events=new_operation(operation={'and': (self.eventA, self.eventB)})) - assert len(quest.fail_events) == 0 + cls.quest = Quest(win_event=cls.event_closing_chest_with_carrot, + fail_event=cls.event_closing_chest_whithout_carrot_or_eating_carrot) - quest = Quest(fail_events=new_operation(operation={'or': (self.eventC, self.eventD)})) - assert len(quest.win_events) == 0 + def test_backward_compatiblity(self): + # Backward compatibility tests. + quest = Quest(win_events=[self.event_closing_chest_with_carrot], + fail_events=[self.event_closing_chest_without_carrot, self.event_eat_carrot]) + assert quest == self.quest - quest = Quest(win_events=new_operation(operation={'and': (self.eventA, self.eventB)}), - fail_events=new_operation(operation={'or': (self.eventC, self.eventD)})) + quest = Quest([self.event_closing_chest_with_carrot], + [self.event_closing_chest_without_carrot, self.event_eat_carrot]) + assert quest == self.quest - assert len(quest.win_events) > 0 - assert len(quest.fail_events) > 0 + def test_is_winning_or_failing(self): + close_chest = Action.parse("close :: open(chest: c) -> closed(chest: c)") + eat_carrot = Action.parse("eat :: in(carrot: f, I: I) -> consumed(carrot: f)") + carrot_in_chest = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, chest: c)"), + ]) + carrot_in_inventory = State(KnowledgeBase.default().logic, [ + Proposition.parse("in(carrot: f, I: I)"), + ]) + + assert self.quest.is_winning(state=carrot_in_chest, action=close_chest) + assert not self.quest.is_failing(state=carrot_in_chest, action=close_chest) + assert self.quest.is_failing(state=carrot_in_inventory, action=close_chest) + assert not self.quest.is_winning(state=carrot_in_inventory, action=close_chest) + assert self.quest.is_failing(state=carrot_in_inventory, action=eat_carrot) + assert not self.quest.is_winning(state=carrot_in_inventory, action=eat_carrot) + assert self.quest.is_failing(state=carrot_in_chest, action=eat_carrot) + assert not self.quest.is_winning(state=carrot_in_chest, action=eat_carrot) def test_serialization(self): data = self.quest.serialize() @@ -435,7 +397,7 @@ def _rule_to_skip(rule): actions = chain.actions assert len(actions) == max_depth, rule.name - quest = Quest(win_events=new_operation(operation={'and': (EventCondition(actions=actions))})) + quest = Quest(win_event=EventCondition(actions=actions)) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -446,7 +408,7 @@ def _rule_to_skip(rule): assert quest.is_winning(state) # Build the quest by only providing the winning conditions. - quest = Quest(win_events=new_operation(operation={'and': (EventCondition(conditions=actions[-1].postconditions))})) + quest = Quest(win_event=EventCondition(conditions=actions[-1].postconditions)) tmp_world = World.from_facts(chain.initial_state.facts) state = tmp_world.state @@ -456,86 +418,6 @@ def _rule_to_skip(rule): assert quest.is_winning(state) - def test_win_actions(self): - state = self.game.world.state.copy() - for action in self.quest.win_events_list[0].actions: - assert not self.quest.is_winning(state) - state.apply(action) - - assert self.quest.is_winning(state) - - # Test alternative way of winning, - # i.e. dropping the carrot and closing the door. - state = self.game.world.state.copy() - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - - drop_carrot = _find_action("drop carrot", actions, self.inform7) - open_door = _find_action("open wooden door", actions, self.inform7) - - state = self.game.world.state.copy() - assert state.apply(drop_carrot) - assert self.quest.is_winning(state) - assert state.apply(open_door) - assert not self.quest.is_winning(state) - - # Or the other way around. - state = self.game.world.state.copy() - assert state.apply(open_door) - assert not self.quest.is_winning(state) - assert state.apply(drop_carrot) - assert not self.quest.is_winning(state) - - def test_fail_actions(self): - state = self.game.world.state.copy() - assert not self.quest.is_failing(state) - - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - eat_carrot = _find_action("eat carrot", actions, self.inform7) - open_door = _find_action("open wooden door", actions, self.inform7) - state.apply(open_door) - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - - for action in actions: - state = self.game.world.state.copy() - state.apply(action) - # Only the `eat carrot` should fail. - assert self.quest.is_failing(state) == (action == eat_carrot) - - state = self.game.world.state.copy() - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - open_door = _find_action("open wooden door", actions, self.inform7) - state.apply(open_door) - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - go_east = _find_action("go east", actions, self.inform7) - state.apply(go_east) # Move to the kitchen. - actions = list(state.all_applicable_actions(self.game.kb.rules.values(), - self.game.kb.types.constants_mapping)) - close_door = _find_action("close wooden door", actions, self.inform7) - close_chest = _find_action("close chest", actions, self.inform7) - - # Only closing the door doesn't fail the quest. - state_ = state.apply_on_copy(close_door) - assert not self.quest.is_failing(state_) - - # Only closing the chest doesn't fail the quest. - state_ = state.apply_on_copy(close_chest) - assert not self.quest.is_failing(state_) - - # Closing the chest, then the door should fail the quest. - state_ = state.apply_on_copy(close_chest) - state_.apply(close_door) - assert self.quest.is_failing(state_) - - # Closing the door, then the chest should fail the quest. - state_ = state.apply_on_copy(close_door) - state_.apply(close_chest) - assert self.quest.is_failing(state_) - class TestGame(unittest.TestCase): @@ -544,7 +426,7 @@ def setUpClass(cls): M = GameMaker() # The goal - commands = ["open wooden door", "go east", "insert carrot into chest"] + commands = ["go east", "insert carrot into chest"] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -552,8 +434,8 @@ def setUpClass(cls): M.set_player(R1) path = M.connect(R1.east, R2.west) - door_a = M.new_door(path, name="wooden door") - M.add_fact("closed", door_a) + path.door = M.new(type='d', name='wooden door') + path.door.add_property("open") carrot = M.new(type='f', name='carrot') M.inventory.add(carrot) @@ -563,8 +445,9 @@ def setUpClass(cls): chest.add_property("open") R2.add(chest) - M.set_quest_from_commands(commands, event_style='condition') + M.set_quest_from_commands(commands) cls.game = M.build() + cls.walkthrough = commands def test_directions_names(self): expected = set(["north", "south", "east", "west"]) @@ -588,6 +471,9 @@ def test_verbs(self): "inventory", "examine"} assert set(self.game.verbs) == expected_verbs + def test_walkthrough(self): + assert self.game.walkthrough == self.walkthrough + def test_command_templates(self): expected_templates = { 'close {c}', 'close {d}', 'drop {o}', 'eat {f}', 'examine {d}', @@ -617,36 +503,36 @@ class TestEventProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - # The goal - commands = ["take carrot", "insert carrot into chest"] - - R1 = M.new_room("room") - M.set_player(R1) - - carrot = M.new(type='f', name='carrot') - R1.add(carrot) + cls.game = DATA["game"] + cls.win_event = DATA["quest"].win_event + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R1.add(chest) + def test_triggering_policy(self): + event = EventProgression(self.win_event, KnowledgeBase.default()) - cls.event = new_operation([M.new_event_using_commands(commands, event_style='condition')]) - cls.actions = cls.event[0].events[0].actions - cls.conditions = {M.new_fact("in", carrot, chest)} - cls.game = M.build() - commands = ["take carrot", "eat carrot"] - cls.eating_carrot = new_operation([M.new_event_using_commands(commands, event_style='condition')]) + state = self.game.world.state.copy() + for action in event.triggering_policy: + assert not event.done + assert not event.triggered + assert not event.untriggerable + state.apply(action) + event.update(action=action, state=state) - def test_triggering_policy(self): - event = EventProgression(self.event[0], KnowledgeBase.default()) + assert event.triggering_policy == () + assert event.done + assert event.triggered + assert not event.untriggerable + event = EventProgression(self.win_event, KnowledgeBase.default()) state = self.game.world.state.copy() - expected_actions = self.event[0].events[0].actions + + expected_actions = self.eating_carrot.actions for i, action in enumerate(expected_actions): - assert event.triggering_policy == expected_actions[i:] + state.apply(action) + event.update(action=action, state=state) + + for action in event.triggering_policy: assert not event.done assert not event.triggered assert not event.untriggerable @@ -659,10 +545,18 @@ def test_triggering_policy(self): assert not event.untriggerable def test_untriggerable(self): - event = EventProgression(self.event[0], KnowledgeBase.default()) + event = EventProgression(self.win_event, KnowledgeBase.default()) state = self.game.world.state.copy() - for action in self.eating_carrot[0].events[0].actions: + for action in self.eating_carrot.actions: + assert event.triggering_policy != () + assert not event.done + assert not event.triggered + assert not event.untriggerable + state.apply(action) + event.update(action=action, state=state) + + for action in self.onion_eaten.actions: assert event.triggering_policy != () assert not event.done assert not event.triggered @@ -680,84 +574,71 @@ class TestQuestProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - room = M.new_room("room") - M.set_player(room) - - carrot = M.new(type='f', name='carrot') - lettuce = M.new(type='f', name='lettuce') - room.add(carrot) - room.add(lettuce) - - chest = M.new(type='c', name='chest') - chest.add_property("open") - room.add(chest) - - # The goals - commands = ["take carrot", "insert carrot into chest"] - cls.eventA = M.new_event_using_commands(commands, event_style='condition') - - commands = ["take lettuce", "insert lettuce into chest", "close chest"] - event = M.new_event_using_commands(commands, event_style='condition') - cls.eventB = EventCondition(actions=event.actions, - conditions={M.new_fact("in", lettuce, chest), M.new_fact("closed", chest)}) - - cls.fail_eventA = EventCondition(conditions={M.new_fact("eaten", carrot)}) - cls.fail_eventB = EventCondition(conditions={M.new_fact("eaten", lettuce)}) - - cls.quest = M.new_quest(win_event={'or': (cls.eventA, cls.eventB)}, - fail_event={'or': (cls.fail_eventA, cls.fail_eventB)}, reward=2) - - commands = ["take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands, event_style='condition') - commands = ["take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands, event_style='condition') - - M.quests = [cls.quest] - cls.game = M.build() - - def _apply_actions_to_quest(self, actions, quest): - state = self.game.world.state.copy() + cls.game = DATA["game"] + cls.quest = DATA["quest"] + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] + cls.closing_chest_without_carrot = DATA["closing_chest_without_carrot"] + + def _apply_actions_to_quest(self, actions, quest, state=None): + state = state or self.game.world.state.copy() for action in actions: assert not quest.done state.apply(action) quest.update(action, state) - assert quest.done - return quest + return state def test_completed(self): quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eventA.actions, quest) - assert quest.completed - assert not quest.failed + self._apply_actions_to_quest(self.quest.win_event.events[0].actions, quest) + assert quest.done + assert quest.completed and not quest.failed assert quest.winning_policy is None # Alternative winning strategy. quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eventB.actions, quest) - assert quest.completed - assert not quest.failed + self._apply_actions_to_quest(self.quest.win_event.events[1].actions, quest) + assert quest.done + assert quest.completed and not quest.failed + assert quest.winning_policy is None + + # Alternative winning strategy but with carrot in inventory. + quest = QuestProgression(self.quest, KnowledgeBase.default()) + state = self._apply_actions_to_quest(self.eating_carrot.actions[:1], quest) # Take carrot. + self._apply_actions_to_quest(self.quest.win_event.events[1].actions, quest, state) + assert quest.done + assert quest.completed and not quest.failed assert quest.winning_policy is None def test_failed(self): + # onion_eaten -> eating_carrot != eating_carrot -> onion_eaten + # Eating the carrot *after* eating the onion causes the game to be lost. quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eating_carrot.actions, quest) - assert not quest.completed - assert quest.failed + state = self._apply_actions_to_quest(self.onion_eaten.actions, quest) + self._apply_actions_to_quest(self.eating_carrot.actions, quest, state) + assert quest.done + assert not quest.completed and quest.failed assert quest.winning_policy is None + # Eating the carrot *before* eating the onion does not lose the game, + # but the game becomes unfinishable. + quest = QuestProgression(self.quest, KnowledgeBase.default()) + state = self._apply_actions_to_quest(self.eating_carrot.actions, quest) + self._apply_actions_to_quest(self.onion_eaten.actions, quest, state) + assert quest.done and quest.unfinishable + assert not quest.completed and not quest.failed + quest = QuestProgression(self.quest, KnowledgeBase.default()) - quest = self._apply_actions_to_quest(self.eating_lettuce.actions, quest) - assert not quest.completed - assert quest.failed + self._apply_actions_to_quest(self.closing_chest_without_carrot.actions, quest) + assert quest.done + assert not quest.completed and quest.failed assert quest.winning_policy is None def test_winning_policy(self): kb = KnowledgeBase.default() quest = QuestProgression(self.quest, kb) - quest = self._apply_actions_to_quest(quest.winning_policy, quest) + self._apply_actions_to_quest(quest.winning_policy, quest) assert quest.completed assert not quest.failed assert quest.winning_policy is None @@ -765,13 +646,15 @@ def test_winning_policy(self): # Winning policy should be the shortest one leading to a winning event. state = self.game.world.state.copy() quest = QuestProgression(self.quest, KnowledgeBase.default()) - for i, action in enumerate(self.eventB.actions): + for i, action in enumerate(self.quest.win_event.events[1].actions): if i < 2: - assert quest.winning_policy == self.eventA.actions - # else: - # # After taking the lettuce and putting it in the chest, - # # QuestB becomes the shortest one to complete. - # assert quest.winning_policy == self.eventB.actions[i:] + assert set(quest.winning_policy).issubset(set(self.quest.win_event.events[0].actions)) + assert not set(quest.winning_policy).issubset(set(self.quest.win_event.events[1].actions)) + else: + # After opening the chest and taking the onion, + # the alternative winning event becomes the shortest one to complete. + assert quest.winning_policy == self.quest.win_event.events[1].actions[i:] + assert not quest.done state.apply(action) quest.update(action, state) @@ -786,134 +669,54 @@ class TestGameProgression(unittest.TestCase): @classmethod def setUpClass(cls): - M = GameMaker() - - # Create a 'bedroom' room. - R1 = M.new_room("bedroom") - R2 = M.new_room("kitchen") - M.set_player(R2) - - path = M.connect(R1.east, R2.west) - door_a = M.new_door(path, name="wooden door") - M.add_fact("closed", door_a) - - carrot = M.new(type='f', name='carrot') - lettuce = M.new(type='f', name='lettuce') - R1.add(carrot) - R1.add(lettuce) - - tomato = M.new(type='f', name='tomato') - pepper = M.new(type='f', name='pepper') - M.inventory.add(tomato) - M.inventory.add(pepper) - - # Add a closed chest in R2. - chest = M.new(type='c', name='chest') - chest.add_property("open") - R2.add(chest) - - # The goals - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] - cls.eventA = M.new_event_using_commands(commands, event_style='condition') - - commands = ["open wooden door", "go west", "take lettuce", "go east", "insert lettuce into chest"] - cls.eventB = M.new_event_using_commands(commands, event_style='condition') - - commands = ["drop pepper"] - cls.eventC = M.new_event_using_commands(commands, event_style='condition') - - cls.losing_eventA = EventCondition(conditions={M.new_fact("eaten", carrot)}) - cls.losing_eventB = EventCondition(conditions={M.new_fact("eaten", lettuce)}) - - cls.questA = M.new_quest(win_event=[cls.eventA], fail_event=[cls.losing_eventA]) - cls.questB = M.new_quest(win_event=[cls.eventB], fail_event=[cls.losing_eventB]) - cls.questC = M.new_quest(win_event=[cls.eventC], fail_event=[]) - cls.questD = M.new_quest(win_event=[], fail_event=[cls.losing_eventA, cls.losing_eventB]) - - commands = ["open wooden door", "go west", "take carrot", "eat carrot"] - cls.eating_carrot = M.new_event_using_commands(commands, event_style='condition') - commands = ["open wooden door", "go west", "take lettuce", "eat lettuce"] - cls.eating_lettuce = M.new_event_using_commands(commands, event_style='condition') - commands = ["eat tomato"] - cls.eating_tomato = M.new_event_using_commands(commands, event_style='condition') - commands = ["eat pepper"] - cls.eating_pepper = M.new_event_using_commands(commands, event_style='condition') - - M.quests = [cls.questA, cls.questB, cls.questC] - cls.game = M.build() + cls.game = DATA["game"] + cls.quest1 = DATA["quest1"] + cls.quest2 = DATA["quest2"] + cls.eating_carrot = DATA["eating_carrot"] + cls.onion_eaten = DATA["onion_eaten"] + cls.knife_on_counter = DATA["knife_on_counter"] def test_completed(self): game = GameProgression(self.game) - for action in self.eventA.actions + self.eventC.actions: + for action in self.quest1.win_event.events[0].actions + self.quest2.win_event.events[0].actions: assert not game.done game.update(action) - assert not game.done - remaining_actions = self.eventB.actions[1:] # skipping "open door". - assert game.winning_policy == remaining_actions - - for action in self.eventB.actions: - assert not game.done - game.update(action) - - assert game.done - assert game.completed - assert not game.failed - assert game.winning_policy is None - - def test_failed(self): - game = GameProgression(self.game) - action = self.eating_tomato.actions[0] - game.update(action) - assert not game.done - assert not game.completed - assert not game.failed - assert game.winning_policy is not None - - game = GameProgression(self.game) - action = self.eating_pepper.actions[0] - game.update(action) - assert not game.completed - assert game.failed assert game.done + assert game.completed and not game.failed assert game.winning_policy is None + # Alternative quest1 solution game = GameProgression(self.game) - for action in self.eating_carrot.actions: + for action in self.quest1.win_event.events[1].actions + self.quest2.win_event.events[0].actions: assert not game.done game.update(action) assert game.done - assert not game.completed - assert game.failed + assert game.completed and not game.failed assert game.winning_policy is None + def test_failed(self): game = GameProgression(self.game) - for action in self.eating_lettuce.actions: - assert not game.done - game.update(action) - assert game.done - assert not game.completed - assert game.failed - assert game.winning_policy is None - - # Completing QuestA but failing quest B. - game = GameProgression(self.game) - for action in self.eventA.actions: - assert not game.done + # Completing quest2 but failing quest 1. + for action in self.knife_on_counter.actions: game.update(action) + assert not game.quest_progressions[0].done + assert game.quest_progressions[1].done + assert game.quest_progressions[1].completed assert not game.done + assert not game.completed and not game.failed + assert game.winning_policy is not None - game = GameProgression(self.game) - for action in self.eating_lettuce.actions: - assert not game.done + for action in self.onion_eaten.actions + self.eating_carrot.actions: game.update(action) assert game.done - assert not game.completed - assert game.failed + assert game.quest_progressions[0].done + assert game.quest_progressions[0].failed + assert not game.completed and game.failed assert game.winning_policy is None def test_winning_policy(self): @@ -944,8 +747,8 @@ def test_cycle_in_winning_policy(self): R4 = M.new_room("r4") M.set_player(R1) - M.connect(R0.south, R1.north) - M.connect(R1.east, R2.west) + M.connect(R0.south, R1.north), + M.connect(R1.east, R2.west), M.connect(R3.east, R4.west) M.connect(R1.south, R3.north) M.connect(R2.south, R4.north) @@ -957,7 +760,8 @@ def test_cycle_in_winning_policy(self): R2.add(apple) commands = ["go north", "take carrot"] - M.set_quest_from_commands(commands, event_style='condition') + M.set_quest_from_commands(commands) + M.set_walkthrough(commands) # TODO: redundant! game = M.build() inform7 = Inform7Game(game) game_progression = GameProgression(game) @@ -981,7 +785,8 @@ def test_cycle_in_winning_policy(self): # Quest where player's has to pick up the carrot first. commands = ["go east", "take apple", "go west", "go north", "drop apple"] - M.set_quest_from_commands(commands, event_style='condition') + M.set_quest_from_commands(commands) + M.set_walkthrough(commands) # TODO: redundant! game = M.build() game_progression = GameProgression(game) @@ -1016,8 +821,8 @@ def test_game_with_multiple_quests(self): M.set_player(R2) path = M.connect(R1.east, R2.west) - door_a = M.new_door(path, name="wooden door") - M.add_fact("closed", door_a) + path.door = M.new(type='d', name='wooden door') + path.door.add_property("closed") carrot = M.new(type='f', name='carrot') lettuce = M.new(type='f', name='lettuce') @@ -1028,18 +833,19 @@ def test_game_with_multiple_quests(self): chest.add_property("open") R2.add(chest) - quest1 = M.new_quest_using_commands(commands[0], event_style='condition') + quest1 = M.new_quest_using_commands(commands[0]) quest1.desc = "Fetch the carrot and drop it on the kitchen's ground." - quest2 = M.new_quest_using_commands(commands[0] + commands[1], event_style='condition') + quest2 = M.new_quest_using_commands(commands[0] + commands[1]) quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground." - quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], event_style='condition') + # quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) winning_facts = [M.new_fact("in", lettuce, chest), M.new_fact("in", carrot, chest), M.new_fact("closed", chest)] - quest3.win_events[0].events[0].set_conditions(winning_facts) + quest3 = Quest(win_event=EventCondition(winning_facts)) quest3.desc = "Put the lettuce and the carrot into the chest before closing it." M.quests = [quest1, quest2, quest3] + M.set_walkthrough(commands[0] + commands[1] + commands[2]) assert len(M.quests) == len(commands) game = M.build() diff --git a/textworld/generator/tests/test_maker.py b/textworld/generator/tests/test_maker.py index 3aebc523..6ffb6f70 100644 --- a/textworld/generator/tests/test_maker.py +++ b/textworld/generator/tests/test_maker.py @@ -113,8 +113,8 @@ def test_making_a_small_game(play_the_game=False): path = M.connect(R1.east, R2.west) # Undirected path # Add a closed door between R1 and R2. - door = M.new_door(path, name="glass door") - M.add_fact("locked", door) + door = M.new_door(path, name='glass door') + door.add_property("locked") # Put a matching key for the door on R1's floor. key = M.new(type='k', name='rusty key') @@ -148,7 +148,7 @@ def test_record_quest_from_commands(play_the_game=False): M = GameMaker() # The goal - commands = ["open wooden door", "go east", "insert ball into chest"] + commands = ["go east", "insert ball into chest"] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -156,8 +156,8 @@ def test_record_quest_from_commands(play_the_game=False): M.set_player(R1) path = M.connect(R1.east, R2.west) - door_a = M.new_door(path, name="wooden door") - M.add_fact("closed", door_a) + path.door = M.new(type='d', name='wooden door') + path.door.add_property("open") ball = M.new(type='o', name='ball') M.inventory.add(ball) @@ -167,7 +167,7 @@ def test_record_quest_from_commands(play_the_game=False): chest.add_property("open") R2.add(chest) - M.set_quest_from_commands(commands, event_style='condition') + M.set_quest_from_commands(commands) game = M.build() with make_temp_directory(prefix="test_record_quest_from_commands") as tmpdir: diff --git a/textworld/generator/tests/test_text_generation.py b/textworld/generator/tests/test_text_generation.py index 290555ac..fe4c72f7 100644 --- a/textworld/generator/tests/test_text_generation.py +++ b/textworld/generator/tests/test_text_generation.py @@ -57,14 +57,15 @@ def test_blend_instructions(verbose=False): M.set_player(r1) path = M.connect(r1.north, r2.south) - door_a = M.new_door(path, name="wooden door") - M.add_fact("locked", door_a) + path.door = M.new(type="d", name="door") + M.add_fact("locked", path.door) key = M.new(type="k", name="key") M.add_fact("match", key, path.door) r1.add(key) quest = M.set_quest_from_commands(["take key", "unlock door with key", "open door", "go north", - "close door", "lock door with key", "drop key"], event_style='condition') + "close door", "lock door with key", "drop key"]) + game = M.build() grammar1 = textworld.generator.make_grammar({"blend_instructions": False}, diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py index 48a91915..816fbe97 100644 --- a/textworld/generator/text_generation.py +++ b/textworld/generator/text_generation.py @@ -4,9 +4,9 @@ import re from collections import OrderedDict -from typing import Union, Iterable -from textworld.generator.game import Quest, EventCondition, EventAction, EventAnd, EventOr, Game +from textworld.generator.game import Quest, Game +from textworld.generator.game import AbstractEvent from textworld.generator.text_grammar import Grammar from textworld.generator.text_grammar import fix_determinant @@ -381,305 +381,19 @@ def generate_instruction(action, grammar, game, counts): return desc, separator -def make_str(txt): - Text = [] - for t in txt: - if len(t) > 0: - Text += [t] - return Text - - -def quest_counter(counter): - if counter == 0: - return '' - elif counter == 1: - return 'First' - elif counter == 2: - return 'Second' - elif counter == 3: - return 'Third' - else: - return str(counter) + 'th' - - -def describe_quests(game: Game, grammar: Grammar): - counter = 1 - quests_desc_arr = [] - for quest in game.quests: - if quest.desc: - quests_desc_arr.append("The " + quest_counter(counter) + " quest: \n" + quest.desc) - counter += 1 - - quests_desc_arr - if quests_desc_arr: - quests_desc_ = " \n ".join(txt for txt in quests_desc_arr if txt) - quests_desc_ = ": \n " + quests_desc_ + " \n *** " - quests_tag = grammar.get_random_expansion("#all_quests#") - quests_tag = quests_tag.replace("(quests_string)", quests_desc_.strip()) - quests_description = grammar.expand(quests_tag) - quests_description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - quests_description) - else: - quests_tag = grammar.get_random_expansion("#all_quests_non#") - quests_description = grammar.expand(quests_tag) - quests_description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - quests_description) - return quests_description - - def assign_description_to_quest(quest: Quest, game: Game, grammar: Grammar): - desc = [] - indx = '> ' - for event in quest.win_events[0].events: - if isinstance(event, EventCondition) or isinstance(event, EventAction): - st = assign_description_to_event(event, game, grammar) - else: - st = assign_description_to_combined_events(event, game, grammar, indx) + if quest.win_event is None: + return "" - if st: - desc += [st] + event_descriptions = [] + for event in quest.win_event: + event_descriptions += [describe_event(event, game, grammar)] - if quest.reward < 0: - return describe_punishing_quest(make_str(desc), grammar, indx) - else: - return describe_quest(make_str(desc), quest.win_events[0], grammar, indx) - - -def describe_punishing_quest(quest_desc: Iterable[str], grammar: Grammar, index_symbol='> '): - if len(quest_desc) == 0: - description = describe_punishing_quest_none(grammar) - else: - description = describe_punishing_quest(quest_desc, grammar, index_symbol) - - return description - - -def describe_punishing_quest_none(grammar: Grammar): - quest_tag = grammar.get_random_expansion("#punishing_quest_none#") - quest_desc = grammar.expand(quest_tag) - quest_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - quest_desc) + quest_desc = " OR ".join(desc for desc in event_descriptions if desc) return quest_desc -def describe_punishing_quest(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: - only_one_task = len(quest_desc) < 2 - quest_desc = [index_symbol + desc for desc in quest_desc if desc] - quest_txt = " \n ".join(desc for desc in quest_desc if desc) - quest_txt = ": \n " + quest_txt - - if only_one_task: - quest_tag = grammar.get_random_expansion("#punishing_quest_one_task#") - quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) - else: - quest_tag = grammar.get_random_expansion("#punishing_quest_tasks#") - quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) - - description = grammar.expand(quest_tag) - description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - description) - return description - - -def describe_quest(quest_desc: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], - grammar: Grammar, index_symbol='> '): - if len(quest_desc) == 0: - description = describe_quest_none(grammar) - else: - if isinstance(combination_rule, EventOr): - description = describe_quest_or(quest_desc, grammar, index_symbol) - elif isinstance(combination_rule, EventAnd): - description = describe_quest_and(quest_desc, grammar, index_symbol) - - return description - - -def describe_quest_none(grammar: Grammar): - quest_tag = grammar.get_random_expansion("#quest_none#") - quest_desc = grammar.expand(quest_tag) - quest_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - quest_desc) - return quest_desc - - -def describe_quest_or(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: - only_one_task = len(quest_desc) < 2 - quest_desc = [index_symbol + desc for desc in quest_desc if desc] - quest_txt = " \n ".join(desc for desc in quest_desc if desc) - quest_txt = ": \n " + quest_txt - - if only_one_task: - quest_tag = grammar.get_random_expansion("#quest_one_task#") - quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) - else: - quest_tag = grammar.get_random_expansion("#quest_or_tasks#") - quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) - - description = grammar.expand(quest_tag) - description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - description) - return description - - -def describe_quest_and(quest_desc: Iterable[str], grammar: Grammar, index_symbol) -> str: - only_one_task = len(quest_desc) < 2 - quest_desc = [index_symbol + desc for desc in quest_desc if desc] - quest_txt = " \n ".join(desc for desc in quest_desc if desc) - quest_txt = ": \n " + quest_txt - - if only_one_task: - quest_tag = grammar.get_random_expansion("#quest_one_task#") - quest_tag = quest_tag.replace("(combined_task)", quest_txt.strip()) - else: - quest_tag = grammar.get_random_expansion("#quest_and_tasks#") - quest_tag = quest_tag.replace("(list_of_combined_tasks)", quest_txt.strip()) - - description = grammar.expand(quest_tag) - description = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - description) - return description - - -def assign_description_to_combined_events(events: Union[EventAnd, EventOr], game: Game, grammar: Grammar, index_symbol, - _desc=[]): - if isinstance(events, EventCondition) or isinstance(events, EventAction): - _desc += [assign_description_to_event(events, game, grammar)] - return - - index_symbol = '-' + index_symbol - desc, ev_type = [], [] - for event in events.events: - st = assign_description_to_combined_events(event, game, grammar, index_symbol, desc) - ev_type.append(isinstance(event, EventCondition) or isinstance(event, EventAction)) - - if st: - desc += [st] - - if all(ev_type): - st1 = combine_events(make_str(desc), events, grammar) - else: - st1 = combine_tasks(make_str(desc), events, grammar, index_symbol) - - return st1 - - -def combine_events(events: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], grammar: Grammar): - if len(events) == 0: - events_desc = "" - else: - if isinstance(combination_rule, EventOr): - events_desc = describe_event_or(events, grammar) - elif isinstance(combination_rule, EventAnd): - events_desc = describe_event_and(events, grammar) - - return events_desc - - -def describe_event_or(events_desc: Iterable[str], grammar: Grammar) -> str: - only_one_event = len(events_desc) < 2 - combined_event_txt = " , or, ".join(desc for desc in events_desc if desc) - combined_event_txt = ": " + combined_event_txt - - if only_one_event: - combined_event_tag = grammar.get_random_expansion("#combined_one_event#") - combined_event_tag = combined_event_tag.replace("(only_event)", combined_event_txt.strip()) - else: - combined_event_tag = grammar.get_random_expansion("#combined_or_events#") - combined_event_tag = combined_event_tag.replace("(list_of_events)", combined_event_txt.strip()) - - combined_event_desc = grammar.expand(combined_event_tag) - combined_event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - combined_event_desc) - - return combined_event_desc - - -def describe_event_and(events_desc: Iterable[str], grammar: Grammar) -> str: - only_one_event = len(events_desc) < 2 - combined_event_txt = " , and, ".join(desc for desc in events_desc if desc) - combined_event_txt = ": " + combined_event_txt - - if only_one_event: - combined_event_tag = grammar.get_random_expansion("#combined_one_event#") - combined_event_tag = combined_event_tag.replace("(only_event)", combined_event_txt.strip()) - else: - combined_event_tag = grammar.get_random_expansion("#combined_and_events#") - combined_event_tag = combined_event_tag.replace("(list_of_events)", combined_event_txt.strip()) - - combined_event_desc = grammar.expand(combined_event_tag) - combined_event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - combined_event_desc) - - return combined_event_desc - - -def combine_tasks(tasks: Iterable[str], combination_rule: Iterable[Union[EventOr, EventAnd]], - grammar: Grammar, index_symbol: str): - if len(tasks) == 0: - tasks_desc = "" - else: - if isinstance(combination_rule, EventOr): - tasks_desc = describe_tasks_or(tasks, grammar, index_symbol) - if isinstance(combination_rule, EventAnd): - tasks_desc = describe_tasks_and(tasks, grammar, index_symbol) - - return tasks_desc - - -def describe_tasks_and(tasks_desc: Iterable[str], grammar: Grammar, index_symbol: str) -> str: - only_one_task = len(tasks_desc) < 2 - tasks_desc = [index_symbol + desc for desc in tasks_desc if desc] - tasks_txt = " \n ".join(desc for desc in tasks_desc if desc) - tasks_txt = ": \n " + tasks_txt - - if only_one_task: - combined_task_tag = grammar.get_random_expansion("#combined_one_task#") - combined_task_tag = combined_task_tag.replace("(only_task)", tasks_txt.strip()) - else: - combined_task_tag = grammar.get_random_expansion("#combined_and_tasks#") - combined_task_tag = combined_task_tag.replace("(list_of_tasks)", tasks_txt.strip()) - - combined_task_desc = grammar.expand(combined_task_tag) - combined_task_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - combined_task_desc) - return combined_task_desc - - -def describe_tasks_or(tasks_desc: Iterable[str], grammar: Grammar, index_symbol: str) -> str: - only_one_task = len(tasks_desc) < 2 - tasks_desc = [index_symbol + desc for desc in tasks_desc if desc] - tasks_txt = " \n ".join(desc for desc in tasks_desc if desc) - tasks_txt = ": \n " + tasks_txt - - if only_one_task: - combined_task_tag = grammar.get_random_expansion("#combined_one_task#") - combined_task_tag = combined_task_tag.replace("(only_task)", tasks_txt.strip()) - else: - combined_task_tag = grammar.get_random_expansion("#combined_and_tasks#") - combined_task_tag = combined_task_tag.replace("(list_of_tasks)", tasks_txt.strip()) - - combined_task_desc = grammar.expand(combined_task_tag) - combined_task_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", - lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), - combined_task_desc) - return combined_task_desc - - -def assign_description_to_event(events: Union[EventAction, EventCondition], game: Game, grammar: Grammar): - return describe_event(events, game, grammar) - - -def describe_event(event: Union[EventCondition, EventAction], game: Game, grammar: Grammar) -> str: +def describe_event(event: AbstractEvent, game: Game, grammar: Grammar) -> str: """ Assign a descripton to a quest. """ @@ -714,7 +428,7 @@ def describe_event(event: Union[EventCondition, EventAction], game: Game, gramma if grammar.options.blend_instructions: instructions = get_action_chains(event.actions, grammar, game) else: - instructions = [act for act in event.actions] + instructions = event.actions only_one_action = len(instructions) < 2 for c in instructions: @@ -724,13 +438,19 @@ def describe_event(event: Union[EventCondition, EventAction], game: Game, gramma actions_desc_list.append(separator) actions_desc = " ".join(actions_desc_list) - event_tag = grammar.get_random_expansion("#event#") - event_tag = event_tag.replace("(list_of_actions)", actions_desc.strip()) + if only_one_action: + quest_tag = grammar.get_random_expansion("#quest_one_action#") + quest_tag = quest_tag.replace("(action)", actions_desc.strip()) - event_desc = grammar.expand(event_tag) + else: + quest_tag = grammar.get_random_expansion("#quest#") + quest_tag = quest_tag.replace("(list_of_actions)", actions_desc.strip()) + + event_desc = grammar.expand(quest_tag) event_desc = re.sub(r"(^|(?<=[?!.]))\s*([a-z])", lambda pat: pat.group(1) + ' ' + pat.group(2).upper(), event_desc) + return event_desc diff --git a/textworld/generator/vtypes.py b/textworld/generator/vtypes.py index f9ceb9de..8ee08b1d 100644 --- a/textworld/generator/vtypes.py +++ b/textworld/generator/vtypes.py @@ -178,10 +178,7 @@ def load(cls, path: str): def __getitem__(self, vtype): """ Get VariableType object from its type string. """ vtype = vtype.rstrip("'") - if vtype in self.variables_types.keys(): - return self.variables_types[vtype] - else: - return None + return self.variables_types[vtype] def __contains__(self, vtype): vtype = vtype.rstrip("'") @@ -202,10 +199,9 @@ def descendants(self, vtype): return [] descendants = [] - if self[vtype]: - for child_type in self[vtype].children: - descendants.append(child_type) - descendants += self.descendants(child_type) + for child_type in self[vtype].children: + descendants.append(child_type) + descendants += self.descendants(child_type) return descendants diff --git a/textworld/generator/world.py b/textworld/generator/world.py index 095af9fc..f96d5eab 100644 --- a/textworld/generator/world.py +++ b/textworld/generator/world.py @@ -256,9 +256,9 @@ def _process_rooms(self) -> None: room = self._get_room(fact.arguments[0]) room.add_related_fact(fact) - if fact.definition.endswith("_of"): + if fact.name.endswith("_of"): # Handle room positioning facts. - exit = reverse_direction(fact.definition.split("_of")[0]) + exit = reverse_direction(fact.name.split("_of")[0]) dest = self._get_room(fact.arguments[1]) dest.add_related_fact(fact) assert exit not in room.exits diff --git a/textworld/logic/__init__.py b/textworld/logic/__init__.py index ba1e5019..007a684f 100644 --- a/textworld/logic/__init__.py +++ b/textworld/logic/__init__.py @@ -63,24 +63,6 @@ def _check_type_conflict(name, old_type, new_type): raise ValueError("Conflicting types for `{}`: have `{}` and `{}`.".format(name, old_type, new_type)) -class UnderspecifiedSignatureError(NameError): - def __init__(self): - msg = "The verb and definition of the signature either should both be None or both take values." - super().__init__(msg) - - -class UnderspecifiedPredicateError(NameError): - def __init__(self): - msg = "The verb and definition of the predicate either should both be None or both take values." - super().__init__(msg) - - -class UnderspecifiedPropositionError(NameError): - def __init__(self): - msg = "The verb and definition of the proposition either should both be None or both take values." - super().__init__(msg) - - class _ModelConverter(NodeWalker): """ Converts TatSu model objects to our types. @@ -555,9 +537,6 @@ def deserialize(cls, data: Mapping) -> "Variable": cls, kwargs.get("name", args[0] if len(args) >= 1 else None), tuple(kwargs.get("types", args[1] if len(args) == 2 else [])) - # tuple(kwargs.get("types", args[1] if len(args) >= 2 else [])), - # kwargs.get("verb", args[2] if len(args) >= 3 else None), - # kwargs.get("definition", args[3] if len(args) == 4 else None), ) ) @@ -568,7 +547,7 @@ class Signature(with_metaclass(SignatureTracker, object)): The type signature of a Predicate or Proposition. """ - __slots__ = ("name", "types", "_hash", "verb", "definition") + __slots__ = ("name", "types", "_hash") def __init__(self, name: str, types: Iterable[str]): """ @@ -582,19 +561,8 @@ def __init__(self, name: str, types: Iterable[str]): The types of the parameters to the proposition/predicate. """ - if name.count('__') == 0: - self.verb = "is" - self.definition = name - self.name = "is__" + name - else: - self.verb = name[:name.find('__')] - self.definition = name[name.find('__') + 2:] - self.name = name - - # self.name = name + self.name = name self.types = tuple(types) - # self.verb = verb - # self.definition = definition self._hash = hash((self.name, self.types)) def __str__(self): @@ -636,11 +604,7 @@ def parse(cls, expr: str) -> "Signature": lambda cls, args, kwargs: ( cls, kwargs.get("name", args[0] if len(args) >= 1 else None), - tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) == 2 else [])), - # tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) >= 2 else [])), - # kwargs.get("verb", args[2] if len(args) >= 3 else None), - # kwargs.get("definition", args[3] if len(args) >= 4 else None), - # kwargs.get("activate", args[4] if len(args) == 5 else 0) + tuple(v.name for v in kwargs.get("arguments", args[1] if len(args) == 2 else [])) ) ) @@ -651,7 +615,7 @@ class Proposition(with_metaclass(PropositionTracker, object)): An instantiated Predicate, with concrete variables for each placeholder. """ - __slots__ = ("name", "arguments", "signature", "_hash", "verb", "definition", "activate") + __slots__ = ("name", "arguments", "signature", "_hash") def __init__(self, name: str, arguments: Iterable[Variable] = []): """ @@ -665,27 +629,11 @@ def __init__(self, name: str, arguments: Iterable[Variable] = []): The variables this proposition is applied to. """ - if name.count('__') == 0: - self.verb = "is" - self.definition = name - self.name = "is__" + name - else: - self.verb = name[:name.find('__')].replace('_', ' ') - self.definition = name[name.find('__') + 2:] - self.name = name - - # self.name = name + self.name = name self.arguments = tuple(arguments) - # self.verb = verb - # self.definition = definition self.signature = Signature(name, [var.type for var in self.arguments]) self._hash = hash((self.name, self.arguments)) - # if self.verb == 'is': - # activate = 1 - # - # self.activate = activate - @property def names(self) -> Collection[str]: """ @@ -708,7 +656,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, Proposition): - return (self.name, self.arguments) == (other.name, other.arguments) + return self.name == other.name and self.arguments == other.arguments else: return NotImplemented @@ -737,17 +685,12 @@ def serialize(self) -> Mapping: return { "name": self.name, "arguments": [var.serialize() for var in self.arguments], - # "verb": self.verb, - # "definition": self.definition, } @classmethod def deserialize(cls, data: Mapping) -> "Proposition": name = data["name"] args = [Variable.deserialize(arg) for arg in data["arguments"]] - # verb = data["verb"] - # definition = data["definition"] - # activate = data["activate"] return cls(name, args) @@ -844,19 +787,8 @@ def __init__(self, name: str, parameters: Iterable[Placeholder]): The symbolic arguments to this predicate. """ - if name.count('__') == 0: - self.verb = "is" - self.definition = name - self.name = "is__" + name - else: - self.verb = name[:name.find('__')] - self.definition = name[name.find('__') + 2:] - self.name = name - - # self.name = name + self.name = name self.parameters = tuple(parameters) - # self.verb = verb - # self.definition = definition self.signature = Signature(name, [ph.type for ph in self.parameters]) @property @@ -886,7 +818,7 @@ def __eq__(self, other): return NotImplemented def __hash__(self): - return hash((self.name, self.types)) + return hash((self.name, self.parameters)) def __lt__(self, other): if isinstance(other, Predicate): @@ -910,17 +842,12 @@ def serialize(self) -> Mapping: return { "name": self.name, "parameters": [ph.serialize() for ph in self.parameters], - # "verb": self.verb, - # "definition": self.definition } @classmethod def deserialize(cls, data: Mapping) -> "Predicate": name = data["name"] params = [Placeholder.deserialize(ph) for ph in data["parameters"]] - # verb = data["verb"] - # definition = data["definition"] - # return cls(name, params, verb, definition) return cls(name, params) def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": @@ -934,7 +861,6 @@ def substitute(self, mapping: Mapping[Placeholder, Placeholder]) -> "Predicate": """ params = [mapping.get(param, param) for param in self.parameters] - # return Predicate(self.name, params, self.verb, self.definition) return Predicate(self.name, params) def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: @@ -952,8 +878,7 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Proposition: """ args = [mapping[param] for param in self.parameters] - # return Proposition(self.name, arguments=args, verb=self.verb, definition=self.definition) - return Proposition(self.name, arguments=args) + return Proposition(self.name, args) def match(self, proposition: Proposition) -> Optional[Mapping[Placeholder, Variable]]: """ @@ -1140,21 +1065,6 @@ def format_command(self, mapping: Dict[str, str] = {}): mapping = mapping or {v.name: v.name for v in self.variables} return self.command_template.format(**mapping) - def has_traceable(self): - for prop in self.all_propositions: - if not prop.name.startswith('is__'): - return True - return False - - def activate_traceable(self): - for prop in self.all_propositions: - if not prop.name.startswith('is__'): - prop.activate = 1 - - # def is_valid(self): - # aa = self.all_propositions - # return all([prop.activate == 1 for prop in self.all_propositions]) - class Rule: """ @@ -1281,6 +1191,7 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Action: ------- The instantiated Action with each Placeholder mapped to the corresponding Variable. """ + key = tuple(mapping[ph] for ph in self.placeholders) if key in self._cache: return self._cache[key] @@ -1288,8 +1199,7 @@ def instantiate(self, mapping: Mapping[Placeholder, Variable]) -> Action: pre_inst = [pred.instantiate(mapping) for pred in self.preconditions] post_inst = [pred.instantiate(mapping) for pred in self.postconditions] action = Action(self.name, pre_inst, post_inst) - if action.has_traceable(): - action.activate_traceable() + action.command_template = self._make_command_template(mapping) if self.reverse_rule: action.reverse_name = self.reverse_rule.name @@ -1555,23 +1465,6 @@ def _normalize_predicates(self, predicates): result.append(pred) return result - def _predicate_diversity(self): - new_preds = [] - for pred in self.predicates: - for v in ['was', 'has been', 'had been']: - new_preds.append(Signature(name=v.replace(' ', '_') + pred.name[pred.name.find('__'):], types=pred.types)) - self.predicates.update(set(new_preds)) - - def _inform7_predicates_diversity(self): - new_preds = {} - for k, v in self.inform7.predicates.items(): - for vt in ['was', 'has been', 'had been']: - new_preds[Signature(name=vt.replace(' ', '_') + k.name[k.name.find('__'):], types=k.types)] = \ - Inform7Predicate(predicate=Predicate(name=vt.replace(' ', '_') + v.predicate.name[v.predicate.name.find('__'):], - parameters=v.predicate.parameters), - source=v.source.replace('is', vt)) - self.inform7.predicates.update(new_preds) - @classmethod @lru_cache(maxsize=128, typed=False) def parse(cls, document: str) -> "GameLogic": @@ -1586,8 +1479,6 @@ def load(cls, paths: Iterable[str]): for path in paths: with open(path, "r") as f: result._parse(f.read(), path=path) - result._predicate_diversity() - result._inform7_predicates_diversity() result._initialize() return result @@ -1692,6 +1583,7 @@ def are_facts(self, props: Iterable[Proposition]) -> bool: for prop in props: if not self.is_fact(prop): return False + return True @property @@ -1893,6 +1785,7 @@ def all_assignments(self, seen_phs.add(ph) new_phs_by_depth.append(new_phs) + # Placeholders uniquely found in postcondition are considered as free variables. free_vars = [ph for ph in rule.placeholders if ph not in seen_phs] new_phs_by_depth.append(free_vars) @@ -2036,24 +1929,3 @@ def __str__(self): lines.append("})") return "\n".join(lines) - - def get_facts(self): - all_facts = [] - for sig in sorted(self._facts.keys()): - facts = self._facts[sig] - if len(facts) == 0: - continue - for fact in sorted(facts): - all_facts.append(fact) - return all_facts - - def has_traceable(self): - for prop in self.get_facts(): - if not prop.name.startswith('is__'): - return True - return False - - @property - def logic(self): - return self._logic - diff --git a/textworld/logic/parser.py b/textworld/logic/parser.py index d18db13a..e86ae355 100644 --- a/textworld/logic/parser.py +++ b/textworld/logic/parser.py @@ -117,33 +117,20 @@ def _variable_(self): # noqa [] ) - def _predVT_(self, name): - self._constant(name[:name.find('__')].replace('_', ' ')) - - def _predDef_(self, name): - self._constant(name[name.find('__') + 2:]) - @tatsumasu('SignatureNode') def _signature_(self): # noqa - self._predName_() - if self.cst.count('__') == 0: - self.last_node = 'is__' + self.cst self.name_last_node('name') - # self._predVT_(self.ast['name']) - # self.name_last_node('verb') - # self._predDef_(self.ast['name']) - # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') + def block2(): self._name_() self._gather(block2, sep2) self.name_last_node('types') self._token(')') - self.ast._define( ['name', 'types'], [] @@ -152,23 +139,17 @@ def block2(): @tatsumasu('PropositionNode') def _proposition_(self): # noqa self._predName_() - if self.cst.count('__') == 0: - self.last_node = 'is__' + self.cst self.name_last_node('name') - # self._predVT_(self.ast['name']) - # self.name_last_node('verb') - # self._predDef_(self.ast['name']) - # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') + def block2(): self._variable_() self._gather(block2, sep2) self.name_last_node('arguments') self._token(')') - self.ast._define( ['arguments', 'name'], [] @@ -229,23 +210,17 @@ def _placeholder_(self): # noqa @tatsumasu('PredicateNode') def _predicate_(self): # noqa self._predName_() - if self.cst.count('__') == 0: - self.last_node = 'is__' + self.cst self.name_last_node('name') - # self._predVT_(self.ast['name']) - # self.name_last_node('verb') - # self._predDef_(self.ast['name']) - # self.name_last_node('definition') self._token('(') def sep2(): self._token(',') + def block2(): self._placeholder_() self._gather(block2, sep2) self.name_last_node('parameters') self._token(')') - self.ast._define( ['name', 'parameters'], [] diff --git a/textworld/testing.py b/textworld/testing.py index ad6a4e0f..46782a69 100644 --- a/textworld/testing.py +++ b/textworld/testing.py @@ -6,12 +6,13 @@ import sys import contextlib -from typing import Tuple +from typing import Tuple, Optional import numpy as np import textworld from textworld.generator.game import Event, Quest, Game +from textworld.generator.game import EventAction, EventCondition, EventOr, EventAnd from textworld.generator.game import GameOptions @@ -38,7 +39,7 @@ def _compile_test_game(game, options: GameOptions) -> str: "instruction_extension": [] } rng_grammar = np.random.RandomState(1234) - grammar = textworld.generator.make_grammar(grammar_flags, rng=rng_grammar) + grammar = textworld.generator.make_grammar(grammar_flags, rng=rng_grammar, kb=options.kb) game.change_grammar(grammar) game_file = textworld.generator.compile_game(game, options) @@ -107,3 +108,102 @@ def build_and_compile_game(options: GameOptions) -> Tuple[Game, str]: game = build_game(options) game_file = _compile_test_game(game, options) return game, game_file + + +def build_complex_test_game(options: Optional[GameOptions] = None): + M = textworld.GameMaker(options) + + # The goal + quest1_cmds1 = ["open chest", "take carrot", "insert carrot into chest", "close chest"] + quest1_cmds2 = ["open chest", "take onion", "insert onion into chest", "close chest"] + quest2_cmds = ["take knife", "put knife on counter"] + + kitchen = M.new_room("kitchen") + M.set_player(kitchen) + + counter = M.new(type='s', name='counter') + chest = M.new(type='c', name='chest') + chest.add_property("closed") + carrot = M.new(type='f', name='carrot') + onion = M.new(type='f', name='onion') + knife = M.new(type='o', name='knife') + kitchen.add(chest, counter, carrot, onion, knife) + + carrot_in_chest = EventCondition(conditions={M.new_fact("in", carrot, chest)}) + onion_in_chest = EventCondition(conditions={M.new_fact("in", onion, chest)}) + closing_chest = EventAction(action=M.new_action("close/c", chest)) + + either_carrot_or_onion_in_chest = EventOr(events=(carrot_in_chest, onion_in_chest)) + closing_chest_with_either_carrot_or_onion = EventAnd(events=(either_carrot_or_onion_in_chest, closing_chest)) + + carrot_in_inventory = EventCondition(conditions={M.new_fact("in", carrot, M.inventory)}) + closing_chest_without_carrot = EventAnd(events=(carrot_in_inventory, closing_chest)) + + eating_carrot = EventAction(action=M.new_action("eat", carrot)) + onion_eaten = EventCondition(conditions={M.new_fact("eaten", onion)}) + + quest1 = Quest( + win_event=closing_chest_with_either_carrot_or_onion, + fail_event=EventOr([ + closing_chest_without_carrot, + EventAnd([ + eating_carrot, + onion_eaten + ]) + ]), + reward=3, + ) + + knife_on_counter = EventCondition(conditions={M.new_fact("on", knife, counter)}) + + quest2 = Quest( + win_event=knife_on_counter, + reward=5, + ) + + carrot_in_chest.name = "carrot_in_chest" + onion_in_chest.name = "onion_in_chest" + closing_chest.name = "closing_chest" + either_carrot_or_onion_in_chest.name = "either_carrot_or_onion_in_chest" + closing_chest_with_either_carrot_or_onion.name = "closing_chest_with_either_carrot_or_onion" + carrot_in_inventory.name = "carrot_in_inventory" + closing_chest_without_carrot.name = "closing_chest_without_carrot" + eating_carrot.name = "eating_carrot" + onion_eaten.name = "onion_eaten" + knife_on_counter.name = "knife_on_counter" + + M.quests = [quest1, quest2] + M.set_walkthrough( + quest1_cmds1, + quest1_cmds2, + quest2_cmds + ) + game = M.build() + + eating_carrot.commands = ["take carrot", "eat carrot"] + eating_carrot.actions = M.get_action_from_commands(eating_carrot.commands) + onion_eaten.commands = ["take onion", "eat onion"] + onion_eaten.actions = M.get_action_from_commands(onion_eaten.commands) + closing_chest_without_carrot.commands = ["take carrot", "open chest", "close chest"] + closing_chest_without_carrot.actions = M.get_action_from_commands(closing_chest_without_carrot.commands) + knife_on_counter.commands = ["take knife", "put knife on counter"] + knife_on_counter.actions = M.get_action_from_commands(knife_on_counter.commands) + + data = { + "game": game, + "quest": quest1, + "quest1": quest1, + "quest2": quest2, + "carrot_in_chest": carrot_in_chest, + "onion_in_chest": onion_in_chest, + "closing_chest": closing_chest, + "either_carrot_or_onion_in_chest": either_carrot_or_onion_in_chest, + "closing_chest_with_either_carrot_or_onion": closing_chest_with_either_carrot_or_onion, + "carrot_in_inventory": carrot_in_inventory, + "closing_chest_without_carrot": closing_chest_without_carrot, + "eating_carrot": eating_carrot, + "onion_eaten": onion_eaten, + "knife_on_counter": knife_on_counter, + } + + return data