diff --git a/libraries/botbuilder-core/botbuilder/core/bot_state.py b/libraries/botbuilder-core/botbuilder/core/bot_state.py index 855c25ff3..53a45ff64 100644 --- a/libraries/botbuilder-core/botbuilder/core/bot_state.py +++ b/libraries/botbuilder-core/botbuilder/core/bot_state.py @@ -94,7 +94,7 @@ async def save_changes(self, turn_context: TurnContext, force: bool = False) -> cached_state = turn_context.turn_state.get(self._context_service_key) - if force or (cached_state != None and cached_state.is_changed == True): + if force or (cached_state is not None and cached_state.is_changed): storage_key = self.get_storage_key(turn_context) changes : Dict[str, object] = { storage_key: cached_state.state } await self._storage.write(changes) @@ -132,7 +132,7 @@ async def delete(self, turn_context: TurnContext) -> None: await self._storage.delete({ storage_key }) @abstractmethod - async def get_storage_key(self, turn_context: TurnContext) -> str: + def get_storage_key(self, turn_context: TurnContext) -> str: raise NotImplementedError() async def get_property_value(self, turn_context: TurnContext, property_name: str): diff --git a/libraries/botbuilder-core/botbuilder/core/conversation_state.py b/libraries/botbuilder-core/botbuilder/core/conversation_state.py index 967061f0f..69e9a50b1 100644 --- a/libraries/botbuilder-core/botbuilder/core/conversation_state.py +++ b/libraries/botbuilder-core/botbuilder/core/conversation_state.py @@ -21,22 +21,18 @@ def __init__(self, storage: Storage): Where to store namespace: str """ - def call_get_storage_key(context): - key = self.get_storage_key(context) - if key is None: - raise AttributeError(self.no_key_error_message) - else: - return key super(ConversationState, self).__init__(storage, 'ConversationState') - def get_storage_key(self, context: TurnContext): - activity = context.activity - channel_id = getattr(activity, 'channel_id', None) - conversation_id = getattr(activity.conversation, 'id', None) if hasattr(activity, 'conversation') else None + channel_id = context.activity.channel_id or self.__raise_type_error("invalid activity-missing channel_id") + conversation_id = context.activity.conversation.id or self.__raise_type_error( + "invalid activity-missing conversation.id") storage_key = None if channel_id and conversation_id: storage_key = "%s/conversations/%s" % (channel_id,conversation_id) return storage_key + + def __raise_type_error(self, err: str = 'NoneType found while expecting value'): + raise TypeError(err) \ No newline at end of file diff --git a/libraries/botbuilder-core/botbuilder/core/memory_storage.py b/libraries/botbuilder-core/botbuilder/core/memory_storage.py index aeb7a099c..25bb48f5a 100644 --- a/libraries/botbuilder-core/botbuilder/core/memory_storage.py +++ b/libraries/botbuilder-core/botbuilder/core/memory_storage.py @@ -3,6 +3,7 @@ from typing import Dict, List from .storage import Storage, StoreItem +from copy import deepcopy class MemoryStorage(Storage): @@ -35,7 +36,6 @@ async def write(self, changes: Dict[str, StoreItem]): # iterate over the changes for (key, change) in changes.items(): new_value = change - old_state = None old_state_etag = None # Check if the a matching key already exists in self.memory @@ -51,13 +51,13 @@ async def write(self, changes: Dict[str, StoreItem]): new_state = new_value # Set ETag if applicable - if isinstance(new_value, StoreItem): + if hasattr(new_value, 'e_tag'): if old_state_etag is not None and new_value.e_tag != "*" and new_value.e_tag < old_state_etag: raise KeyError("Etag conflict.\nOriginal: %s\r\nCurrent: %s" % \ (new_value.e_tag, old_state_etag) ) new_state.e_tag = str(self._e_tag) self._e_tag += 1 - self.memory[key] = new_state + self.memory[key] = deepcopy(new_state) except Exception as e: raise e diff --git a/libraries/botbuilder-core/botbuilder/core/user_state.py b/libraries/botbuilder-core/botbuilder/core/user_state.py index 28aa8c660..870fcc877 100644 --- a/libraries/botbuilder-core/botbuilder/core/user_state.py +++ b/libraries/botbuilder-core/botbuilder/core/user_state.py @@ -21,13 +21,6 @@ def __init__(self, storage: Storage, namespace=''): """ self.namespace = namespace - def call_get_storage_key(context): - key = self.get_storage_key(context) - if key is None: - raise AttributeError(self.no_key_error_message) - else: - return key - super(UserState, self).__init__(storage, "UserState") def get_storage_key(self, context: TurnContext) -> str: @@ -36,11 +29,14 @@ def get_storage_key(self, context: TurnContext) -> str: :param context: :return: """ - activity = context.activity - channel_id = getattr(activity, 'channel_id', None) - user_id = getattr(activity.from_property, 'id', None) if hasattr(activity, 'from_property') else None + channel_id = context.activity.channel_id or self.__raise_type_error("invalid activity-missing channelId") + user_id = context.activity.from_property.id or self.__raise_type_error( + "invalid activity-missing from_property.id") storage_key = None if channel_id and user_id: storage_key = "%s/users/%s" % (channel_id, user_id) return storage_key + + def __raise_type_error(self, err: str = 'NoneType found while expecting value'): + raise TypeError(err) diff --git a/libraries/botbuilder-core/tests/test_bot_state.py b/libraries/botbuilder-core/tests/test_bot_state.py index f44da08a7..57ee2442a 100644 --- a/libraries/botbuilder-core/tests/test_bot_state.py +++ b/libraries/botbuilder-core/tests/test_bot_state.py @@ -3,9 +3,9 @@ import aiounittest from unittest.mock import MagicMock -from botbuilder.core import TurnContext, BotState, MemoryStorage, UserState +from botbuilder.core import BotState, ConversationState, MemoryStorage, Storage, StoreItem, TurnContext, UserState from botbuilder.core.adapters import TestAdapter -from botbuilder.schema import Activity +from botbuilder.schema import Activity, ConversationAccount from test_utilities import TestUtilities @@ -23,6 +23,23 @@ def key_factory(context): assert context is not None return STORAGE_KEY +class BotStateForTest(BotState): + def __init__(self, storage: Storage): + super().__init__(storage, f"BotState:BotState") + + def get_storage_key(self, turn_context: TurnContext) -> str: + return f"botstate/{turn_context.activity.channel_id}/{turn_context.activity.conversation.id}/BotState" + + +class CustomState(StoreItem): + def __init__(self, custom_string: str = None, e_tag: str = '*'): + super().__init__(custom_string=custom_string, e_tag=e_tag) + + +class TestPocoState: + def __init__(self, value=None): + self.value = value + class TestBotState(aiounittest.AsyncTestCase): storage = MemoryStorage() @@ -334,4 +351,134 @@ async def test_LoadSaveDelete(self): obj2 = dictionary["EmptyContext/users/empty@empty.context.org"] self.assertEqual("hello-2", obj2["property-a"]) with self.assertRaises(KeyError) as _: - obj2["property-b"] \ No newline at end of file + obj2["property-b"] + + async def test_state_use_bot_state_directly(self): + async def exec_test(context: TurnContext): + bot_state_manager = BotStateForTest(MemoryStorage()) + test_property = bot_state_manager.create_property("test") + + # read initial state object + await bot_state_manager.load(context) + + custom_state = await test_property.get(context, lambda: CustomState()) + + # this should be a 'CustomState' as nothing is currently stored in storage + assert isinstance(custom_state, CustomState) + + # amend property and write to storage + custom_state.custom_string = "test" + await bot_state_manager.save_changes(context) + + custom_state.custom_string = "asdfsadf" + + # read into context again + await bot_state_manager.load(context, True) + + custom_state = await test_property.get(context) + + # check object read from value has the correct value for custom_string + assert custom_state.custom_string == "test" + + adapter = TestAdapter(exec_test) + await adapter.send('start') + + async def test_user_state_bad_from_throws(self): + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + context.activity.from_property = None + test_property = user_state.create_property("test") + with self.assertRaises(AttributeError): + await test_property.get(context) + + async def test_conversation_state_bad_converation_throws(self): + dictionary = {} + user_state = ConversationState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + context.activity.conversation = None + test_property = user_state.create_property("test") + with self.assertRaises(AttributeError): + await test_property.get(context) + + async def test_clear_and_save(self): + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + # Turn 0 + bot_state1 = ConversationState(storage) + (await bot_state1 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState())).value = "test-value" + await bot_state1.save_changes(turn_context) + + # Turn 1 + bot_state2 = ConversationState(storage) + value1 = (await bot_state2 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState(value="default-value"))).value + + assert "test-value" == value1 + + # Turn 2 + bot_state3 = ConversationState(storage) + await bot_state3.clear_state(turn_context) + await bot_state3.save_changes(turn_context) + + # Turn 3 + bot_state4 = ConversationState(storage) + value2 = (await bot_state4 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState(value="default-value"))).value + + assert "default-value", value2 + + async def test_bot_state_delete(self): + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + # Turn 0 + bot_state1 = ConversationState(storage) + (await bot_state1 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState())).value = "test-value" + await bot_state1.save_changes(turn_context) + + # Turn 1 + bot_state2 = ConversationState(storage) + value1 = (await bot_state2 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState(value="default-value"))).value + + assert "test-value" == value1 + + # Turn 2 + bot_state3 = ConversationState(storage) + await bot_state3.delete(turn_context) + + # Turn 3 + bot_state4 = ConversationState(storage) + value2 = (await bot_state4 + .create_property("test-name") + .get(turn_context, lambda: TestPocoState(value="default-value"))).value + + assert "default-value" == value2 + + async def test_bot_state_get(self): + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + conversation_state = ConversationState(storage) + (await conversation_state + .create_property("test-name") + .get(turn_context, lambda: TestPocoState())).value = "test-value" + + result = conversation_state.get(turn_context) + + assert "test-value" == result["test-name"].value \ No newline at end of file