diff --git a/libraries/botbuilder-azure/botbuilder/azure/__init__.py b/libraries/botbuilder-azure/botbuilder/azure/__init__.py index 54dea209d..9980f8aa4 100644 --- a/libraries/botbuilder-azure/botbuilder/azure/__init__.py +++ b/libraries/botbuilder-azure/botbuilder/azure/__init__.py @@ -7,6 +7,10 @@ from .about import __version__ from .cosmosdb_storage import CosmosDbStorage, CosmosDbConfig, CosmosDbKeyEscape +from .cosmosdb_partitioned_storage import ( + CosmosDbPartitionedStorage, + CosmosDbPartitionedConfig, +) from .blob_storage import BlobStorage, BlobStorageSettings __all__ = [ @@ -15,5 +19,7 @@ "CosmosDbStorage", "CosmosDbConfig", "CosmosDbKeyEscape", + "CosmosDbPartitionedStorage", + "CosmosDbPartitionedConfig", "__version__", ] diff --git a/libraries/botbuilder-azure/botbuilder/azure/blob_storage.py b/libraries/botbuilder-azure/botbuilder/azure/blob_storage.py index ae3ad1766..fada3fe53 100644 --- a/libraries/botbuilder-azure/botbuilder/azure/blob_storage.py +++ b/libraries/botbuilder-azure/botbuilder/azure/blob_storage.py @@ -3,7 +3,6 @@ from jsonpickle import encode from jsonpickle.unpickler import Unpickler - from azure.storage.blob import BlockBlobService, Blob, PublicAccess from botbuilder.core import Storage @@ -42,7 +41,7 @@ def __init__(self, settings: BlobStorageSettings): async def read(self, keys: List[str]) -> Dict[str, object]: if not keys: - raise Exception("Please provide at least one key to read from storage.") + raise Exception("Keys are required when reading") self.client.create_container(self.settings.container_name) self.client.set_container_acl( @@ -63,24 +62,31 @@ async def read(self, keys: List[str]) -> Dict[str, object]: return items async def write(self, changes: Dict[str, object]): + if changes is None: + raise Exception("Changes are required when writing") + if not changes: + return + self.client.create_container(self.settings.container_name) self.client.set_container_acl( self.settings.container_name, public_access=PublicAccess.Container ) - for name, item in changes.items(): - e_tag = ( - None if not hasattr(item, "e_tag") or item.e_tag == "*" else item.e_tag - ) - if e_tag: - item.e_tag = e_tag.replace('"', '\\"') + for (name, item) in changes.items(): + e_tag = item.e_tag if hasattr(item, "e_tag") else item.get("e_tag", None) + e_tag = None if e_tag == "*" else e_tag + if e_tag == "": + raise Exception("blob_storage.write(): etag missing") item_str = self._store_item_to_str(item) - self.client.create_blob_from_text( - container_name=self.settings.container_name, - blob_name=name, - text=item_str, - if_match=e_tag, - ) + try: + self.client.create_blob_from_text( + container_name=self.settings.container_name, + blob_name=name, + text=item_str, + if_match=e_tag, + ) + except Exception as error: + raise error async def delete(self, keys: List[str]): if keys is None: @@ -102,7 +108,6 @@ async def delete(self, keys: List[str]): def _blob_to_store_item(self, blob: Blob) -> object: item = json.loads(blob.content) item["e_tag"] = blob.properties.etag - item["id"] = blob.name result = Unpickler().restore(item) return result diff --git a/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_partitioned_storage.py b/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_partitioned_storage.py new file mode 100644 index 000000000..00c3bb137 --- /dev/null +++ b/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_partitioned_storage.py @@ -0,0 +1,285 @@ +"""CosmosDB Middleware for Python Bot Framework. + +This is middleware to store items in CosmosDB. +Part of the Azure Bot Framework in Python. +""" + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from typing import Dict, List +from threading import Semaphore +import json + +from azure.cosmos import documents, http_constants +from jsonpickle.pickler import Pickler +from jsonpickle.unpickler import Unpickler +import azure.cosmos.cosmos_client as cosmos_client # pylint: disable=no-name-in-module,import-error +import azure.cosmos.errors as cosmos_errors # pylint: disable=no-name-in-module,import-error +from botbuilder.core.storage import Storage +from botbuilder.azure import CosmosDbKeyEscape + + +class CosmosDbPartitionedConfig: + """The class for partitioned CosmosDB configuration for the Azure Bot Framework.""" + + def __init__( + self, + cosmos_db_endpoint: str = None, + auth_key: str = None, + database_id: str = None, + container_id: str = None, + cosmos_client_options: dict = None, + container_throughput: int = None, + **kwargs, + ): + """Create the Config object. + + :param cosmos_db_endpoint: The CosmosDB endpoint. + :param auth_key: The authentication key for Cosmos DB. + :param database_id: The database identifier for Cosmos DB instance. + :param container_id: The container identifier. + :param cosmos_client_options: The options for the CosmosClient. Currently only supports connection_policy and + consistency_level + :param container_throughput: The throughput set when creating the Container. Defaults to 400. + :return CosmosDbPartitionedConfig: + """ + self.__config_file = kwargs.get("filename") + if self.__config_file: + kwargs = json.load(open(self.__config_file)) + self.cosmos_db_endpoint = cosmos_db_endpoint or kwargs.get("cosmos_db_endpoint") + self.auth_key = auth_key or kwargs.get("auth_key") + self.database_id = database_id or kwargs.get("database_id") + self.container_id = container_id or kwargs.get("container_id") + self.cosmos_client_options = cosmos_client_options or kwargs.get( + "cosmos_client_options", {} + ) + self.container_throughput = container_throughput or kwargs.get( + "container_throughput" + ) + + +class CosmosDbPartitionedStorage(Storage): + """The class for partitioned CosmosDB middleware for the Azure Bot Framework.""" + + def __init__(self, config: CosmosDbPartitionedConfig): + """Create the storage object. + + :param config: + """ + super(CosmosDbPartitionedStorage, self).__init__() + self.config = config + self.client = None + self.database = None + self.container = None + self.__semaphore = Semaphore() + + async def read(self, keys: List[str]) -> Dict[str, object]: + """Read storeitems from storage. + + :param keys: + :return dict: + """ + if not keys: + raise Exception("Keys are required when reading") + + await self.initialize() + + store_items = {} + + for key in keys: + try: + escaped_key = CosmosDbKeyEscape.sanitize_key(key) + + read_item_response = self.client.ReadItem( + self.__item_link(escaped_key), {"partitionKey": escaped_key} + ) + document_store_item = read_item_response + if document_store_item: + store_items[document_store_item["realId"]] = self.__create_si( + document_store_item + ) + # When an item is not found a CosmosException is thrown, but we want to + # return an empty collection so in this instance we catch and do not rethrow. + # Throw for any other exception. + except cosmos_errors.HTTPFailure as err: + if ( + err.status_code + == cosmos_errors.http_constants.StatusCodes.NOT_FOUND + ): + continue + raise err + except Exception as err: + raise err + return store_items + + async def write(self, changes: Dict[str, object]): + """Save storeitems to storage. + + :param changes: + :return: + """ + if changes is None: + raise Exception("Changes are required when writing") + if not changes: + return + + await self.initialize() + + for (key, change) in changes.items(): + e_tag = change.get("e_tag", None) + doc = { + "id": CosmosDbKeyEscape.sanitize_key(key), + "realId": key, + "document": self.__create_dict(change), + } + if e_tag == "": + raise Exception("cosmosdb_storage.write(): etag missing") + + access_condition = { + "accessCondition": {"type": "IfMatch", "condition": e_tag} + } + options = ( + access_condition if e_tag != "*" and e_tag and e_tag != "" else None + ) + try: + self.client.UpsertItem( + database_or_Container_link=self.__container_link, + document=doc, + options=options, + ) + except cosmos_errors.HTTPFailure as err: + raise err + except Exception as err: + raise err + + async def delete(self, keys: List[str]): + """Remove storeitems from storage. + + :param keys: + :return: + """ + await self.initialize() + + for key in keys: + escaped_key = CosmosDbKeyEscape.sanitize_key(key) + try: + self.client.DeleteItem( + document_link=self.__item_link(escaped_key), + options={"partitionKey": escaped_key}, + ) + except cosmos_errors.HTTPFailure as err: + if ( + err.status_code + == cosmos_errors.http_constants.StatusCodes.NOT_FOUND + ): + continue + raise err + except Exception as err: + raise err + + async def initialize(self): + if not self.container: + if not self.client: + self.client = cosmos_client.CosmosClient( + self.config.cosmos_db_endpoint, + {"masterKey": self.config.auth_key}, + self.config.cosmos_client_options.get("connection_policy", None), + self.config.cosmos_client_options.get("consistency_level", None), + ) + + if not self.database: + with self.__semaphore: + try: + self.database = self.client.CreateDatabase( + {"id": self.config.database_id} + ) + except cosmos_errors.HTTPFailure: + self.database = self.client.ReadDatabase( + "dbs/" + self.config.database_id + ) + + if not self.container: + with self.__semaphore: + container_def = { + "id": self.config.container_id, + "partitionKey": { + "paths": ["/id"], + "kind": documents.PartitionKind.Hash, + }, + } + try: + self.container = self.client.CreateContainer( + "dbs/" + self.database["id"], + container_def, + {"offerThroughput": 400}, + ) + except cosmos_errors.HTTPFailure as err: + if err.status_code == http_constants.StatusCodes.CONFLICT: + self.container = self.client.ReadContainer( + "dbs/" + + self.database["id"] + + "/colls/" + + container_def["id"] + ) + else: + raise err + + @staticmethod + def __create_si(result) -> object: + """Create an object from a result out of CosmosDB. + + :param result: + :return object: + """ + # get the document item from the result and turn into a dict + doc = result.get("document") + # read the e_tag from Cosmos + if result.get("_etag"): + doc["e_tag"] = result["_etag"] + + result_obj = Unpickler().restore(doc) + + # create and return the object + return result_obj + + @staticmethod + def __create_dict(store_item: object) -> Dict: + """Return the dict of an object. + + This eliminates non_magic attributes and the e_tag. + + :param store_item: + :return dict: + """ + # read the content + json_dict = Pickler().flatten(store_item) + if "e_tag" in json_dict: + del json_dict["e_tag"] + + # loop through attributes and write and return a dict + return json_dict + + def __item_link(self, identifier) -> str: + """Return the item link of a item in the container. + + :param identifier: + :return str: + """ + return self.__container_link + "/docs/" + identifier + + @property + def __container_link(self) -> str: + """Return the container link in the database. + + :param: + :return str: + """ + return self.__database_link + "/colls/" + self.config.container_id + + @property + def __database_link(self) -> str: + """Return the database link. + + :return str: + """ + return "dbs/" + self.config.database_id diff --git a/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_storage.py b/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_storage.py index c8a25a017..3d588a864 100644 --- a/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_storage.py +++ b/libraries/botbuilder-azure/botbuilder/azure/cosmosdb_storage.py @@ -160,6 +160,10 @@ async def write(self, changes: Dict[str, object]): :param changes: :return: """ + if changes is None: + raise Exception("Changes are required when writing") + if not changes: + return try: # check if the database and container exists and if not create if not self.__container_exists: @@ -167,13 +171,19 @@ async def write(self, changes: Dict[str, object]): # iterate over the changes for (key, change) in changes.items(): # store the e_tag - e_tag = change.e_tag + e_tag = ( + change.e_tag + if hasattr(change, "e_tag") + else change.get("e_tag", None) + ) # create the new document doc = { "id": CosmosDbKeyEscape.sanitize_key(key), "realId": key, "document": self.__create_dict(change), } + if e_tag == "": + raise Exception("cosmosdb_storage.write(): etag missing") # the e_tag will be * for new docs so do an insert if e_tag == "*" or not e_tag: self.client.UpsertItem( @@ -191,9 +201,6 @@ async def write(self, changes: Dict[str, object]): new_document=doc, options={"accessCondition": access_condition}, ) - # error when there is no e_tag - else: - raise Exception("cosmosdb_storage.write(): etag missing") except Exception as error: raise error diff --git a/libraries/botbuilder-azure/tests/test_blob_storage.py b/libraries/botbuilder-azure/tests/test_blob_storage.py index 40db0f61e..31f54a231 100644 --- a/libraries/botbuilder-azure/tests/test_blob_storage.py +++ b/libraries/botbuilder-azure/tests/test_blob_storage.py @@ -4,16 +4,29 @@ import pytest from botbuilder.core import StoreItem from botbuilder.azure import BlobStorage, BlobStorageSettings +from botbuilder.testing import StorageBaseTests # local blob emulator instance blob + BLOB_STORAGE_SETTINGS = BlobStorageSettings( - account_name="", account_key="", container_name="test" + account_name="", + account_key="", + container_name="test", + # Default Azure Storage Emulator Connection String + connection_string="AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq" + + "2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;DefaultEndpointsProtocol=http;BlobEndpoint=" + + "http://127.0.0.1:10000/devstoreaccount1;QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;" + + "TableEndpoint=http://127.0.0.1:10002/devstoreaccount1;", ) EMULATOR_RUNNING = False +def get_storage(): + return BlobStorage(BLOB_STORAGE_SETTINGS) + + async def reset(): - storage = BlobStorage(BLOB_STORAGE_SETTINGS) + storage = get_storage() try: await storage.client.delete_container( container_name=BLOB_STORAGE_SETTINGS.container_name @@ -29,7 +42,7 @@ def __init__(self, counter=1, e_tag="*"): self.e_tag = e_tag -class TestBlobStorage: +class TestBlobStorageConstructor: @pytest.mark.asyncio async def test_blob_storage_init_should_error_without_cosmos_db_config(self): try: @@ -37,17 +50,104 @@ async def test_blob_storage_init_should_error_without_cosmos_db_config(self): except Exception as error: assert error + +class TestBlobStorageBaseTests: @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio - async def test_blob_storage_read_should_return_data_with_valid_key(self): - storage = BlobStorage(BLOB_STORAGE_SETTINGS) - await storage.write({"user": SimpleStoreItem()}) + async def test_return_empty_object_when_reading_unknown_key(self): + await reset() - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 - assert len(data.keys()) == 1 + test_ran = await StorageBaseTests.return_empty_object_when_reading_unknown_key( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_reading(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_reading(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_writing(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_writing(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_does_not_raise_when_writing_no_items(self): + await reset() + + test_ran = await StorageBaseTests.does_not_raise_when_writing_no_items( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_create_object(self): + await reset() + + test_ran = await StorageBaseTests.create_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_crazy_keys(self): + await reset() + + test_ran = await StorageBaseTests.handle_crazy_keys(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_update_object(self): + await reset() + test_ran = await StorageBaseTests.update_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_delete_object(self): + await reset() + + test_ran = await StorageBaseTests.delete_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_perform_batch_operations(self): + await reset() + + test_ran = await StorageBaseTests.perform_batch_operations(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_proceeds_through_waterfall(self): + await reset() + + test_ran = await StorageBaseTests.proceeds_through_waterfall(get_storage()) + + assert test_ran + + +class TestBlobStorage: @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_blob_storage_read_update_should_return_new_etag(self): @@ -60,25 +160,6 @@ async def test_blob_storage_read_update_should_return_new_etag(self): assert data_updated["test"].counter == 2 assert data_updated["test"].e_tag != data_result["test"].e_tag - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_blob_storage_read_no_key_should_throw(self): - try: - storage = BlobStorage(BLOB_STORAGE_SETTINGS) - await storage.read([]) - except Exception as error: - assert error - - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_blob_storage_write_should_add_new_value(self): - storage = BlobStorage(BLOB_STORAGE_SETTINGS) - await storage.write({"user": SimpleStoreItem(counter=1)}) - - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_blob_storage_write_should_overwrite_when_new_e_tag_is_an_asterisk( @@ -91,32 +172,6 @@ async def test_blob_storage_write_should_overwrite_when_new_e_tag_is_an_asterisk data = await storage.read(["user"]) assert data["user"].counter == 10 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_blob_storage_write_batch_operation(self): - storage = BlobStorage(BLOB_STORAGE_SETTINGS) - await storage.write( - { - "batch1": SimpleStoreItem(counter=1), - "batch2": SimpleStoreItem(counter=1), - "batch3": SimpleStoreItem(counter=1), - } - ) - data = await storage.read(["batch1", "batch2", "batch3"]) - assert len(data.keys()) == 3 - assert data["batch1"] - assert data["batch2"] - assert data["batch3"] - assert data["batch1"].counter == 1 - assert data["batch2"].counter == 1 - assert data["batch3"].counter == 1 - assert data["batch1"].e_tag - assert data["batch2"].e_tag - assert data["batch3"].e_tag - await storage.delete(["batch1", "batch2", "batch3"]) - data = await storage.read(["batch1", "batch2", "batch3"]) - assert not data.keys() - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_blob_storage_delete_should_delete_according_cached_data(self): diff --git a/libraries/botbuilder-azure/tests/test_cosmos_partitioned_storage.py b/libraries/botbuilder-azure/tests/test_cosmos_partitioned_storage.py new file mode 100644 index 000000000..cb6dd0822 --- /dev/null +++ b/libraries/botbuilder-azure/tests/test_cosmos_partitioned_storage.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import azure.cosmos.errors as cosmos_errors +from azure.cosmos import documents +import pytest +from botbuilder.azure import CosmosDbPartitionedStorage, CosmosDbPartitionedConfig +from botbuilder.testing import StorageBaseTests + +EMULATOR_RUNNING = False + + +def get_settings() -> CosmosDbPartitionedConfig: + return CosmosDbPartitionedConfig( + cosmos_db_endpoint="https://localhost:8081", + auth_key="C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==", + database_id="test-db", + container_id="bot-storage", + ) + + +def get_storage(): + return CosmosDbPartitionedStorage(get_settings()) + + +async def reset(): + storage = CosmosDbPartitionedStorage(get_settings()) + await storage.initialize() + try: + storage.client.DeleteDatabase(database_link="dbs/" + get_settings().database_id) + except cosmos_errors.HTTPFailure: + pass + + +class TestCosmosDbPartitionedStorageConstructor: + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_raises_error_when_instantiated_with_no_arguments(self): + try: + # noinspection PyArgumentList + # pylint: disable=no-value-for-parameter + CosmosDbPartitionedStorage() + except Exception as error: + assert error + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_raises_error_when_no_endpoint_provided(self): + no_endpoint = get_settings() + no_endpoint.cosmos_db_endpoint = None + try: + CosmosDbPartitionedStorage(no_endpoint) + except Exception as error: + assert error + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_raises_error_when_no_auth_key_provided(self): + no_auth_key = get_settings() + no_auth_key.auth_key = None + try: + CosmosDbPartitionedStorage(no_auth_key) + except Exception as error: + assert error + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_raises_error_when_no_database_id_provided(self): + no_database_id = get_settings() + no_database_id.database_id = None + try: + CosmosDbPartitionedStorage(no_database_id) + except Exception as error: + assert error + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_raises_error_when_no_container_id_provided(self): + no_container_id = get_settings() + no_container_id.container_id = None + try: + CosmosDbPartitionedStorage(no_container_id) + except Exception as error: + assert error + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_passes_cosmos_client_options(self): + settings_with_options = get_settings() + + connection_policy = documents.ConnectionPolicy() + connection_policy.DisableSSLVerification = True + + settings_with_options.cosmos_client_options = { + "connection_policy": connection_policy, + "consistency_level": documents.ConsistencyLevel.Eventual, + } + + client = CosmosDbPartitionedStorage(settings_with_options) + await client.initialize() + + assert client.client.connection_policy.DisableSSLVerification is True + assert ( + client.client.default_headers["x-ms-consistency-level"] + == documents.ConsistencyLevel.Eventual + ) + + +class TestCosmosDbPartitionedStorageBaseStorageTests: + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_return_empty_object_when_reading_unknown_key(self): + await reset() + + test_ran = await StorageBaseTests.return_empty_object_when_reading_unknown_key( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_reading(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_reading(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_writing(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_writing(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_does_not_raise_when_writing_no_items(self): + await reset() + + test_ran = await StorageBaseTests.does_not_raise_when_writing_no_items( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_create_object(self): + await reset() + + test_ran = await StorageBaseTests.create_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_crazy_keys(self): + await reset() + + test_ran = await StorageBaseTests.handle_crazy_keys(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_update_object(self): + await reset() + + test_ran = await StorageBaseTests.update_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_delete_object(self): + await reset() + + test_ran = await StorageBaseTests.delete_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_perform_batch_operations(self): + await reset() + + test_ran = await StorageBaseTests.perform_batch_operations(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_proceeds_through_waterfall(self): + await reset() + + test_ran = await StorageBaseTests.proceeds_through_waterfall(get_storage()) + + assert test_ran diff --git a/libraries/botbuilder-azure/tests/test_cosmos_storage.py b/libraries/botbuilder-azure/tests/test_cosmos_storage.py index a9bfe5191..c66660857 100644 --- a/libraries/botbuilder-azure/tests/test_cosmos_storage.py +++ b/libraries/botbuilder-azure/tests/test_cosmos_storage.py @@ -7,6 +7,7 @@ import pytest from botbuilder.core import StoreItem from botbuilder.azure import CosmosDbStorage, CosmosDbConfig +from botbuilder.testing import StorageBaseTests # local cosmosdb emulator instance cosmos_db_config COSMOS_DB_CONFIG = CosmosDbConfig( @@ -18,6 +19,10 @@ EMULATOR_RUNNING = False +def get_storage(): + return CosmosDbStorage(COSMOS_DB_CONFIG) + + async def reset(): storage = CosmosDbStorage(COSMOS_DB_CONFIG) try: @@ -50,7 +55,7 @@ def __init__(self, counter=1, e_tag="*"): self.e_tag = e_tag -class TestCosmosDbStorage: +class TestCosmosDbStorageConstructor: @pytest.mark.asyncio async def test_cosmos_storage_init_should_error_without_cosmos_db_config(self): try: @@ -59,7 +64,7 @@ async def test_cosmos_storage_init_should_error_without_cosmos_db_config(self): assert error @pytest.mark.asyncio - async def test_creation_request_options_era_being_called(self): + async def test_creation_request_options_are_being_called(self): # pylint: disable=protected-access test_config = CosmosDbConfig( endpoint="https://localhost:8081", @@ -86,6 +91,104 @@ async def test_creation_request_options_era_being_called(self): "dbs/" + test_id, {"id": test_id}, test_config.container_creation_options ) + +class TestCosmosDbStorageBaseStorageTests: + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_return_empty_object_when_reading_unknown_key(self): + await reset() + + test_ran = await StorageBaseTests.return_empty_object_when_reading_unknown_key( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_reading(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_reading(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_null_keys_when_writing(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_writing(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_does_not_raise_when_writing_no_items(self): + await reset() + + test_ran = await StorageBaseTests.does_not_raise_when_writing_no_items( + get_storage() + ) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_create_object(self): + await reset() + + test_ran = await StorageBaseTests.create_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_handle_crazy_keys(self): + await reset() + + test_ran = await StorageBaseTests.handle_crazy_keys(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_update_object(self): + await reset() + + test_ran = await StorageBaseTests.update_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_delete_object(self): + await reset() + + test_ran = await StorageBaseTests.delete_object(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_perform_batch_operations(self): + await reset() + + test_ran = await StorageBaseTests.perform_batch_operations(get_storage()) + + assert test_ran + + @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") + @pytest.mark.asyncio + async def test_proceeds_through_waterfall(self): + await reset() + + test_ran = await StorageBaseTests.proceeds_through_waterfall(get_storage()) + + assert test_ran + + +class TestCosmosDbStorage: @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_cosmos_storage_init_should_work_with_just_endpoint_and_key(self): @@ -100,18 +203,6 @@ async def test_cosmos_storage_init_should_work_with_just_endpoint_and_key(self): assert data["user"].counter == 1 assert len(data.keys()) == 1 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_read_should_return_data_with_valid_key(self): - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - await storage.write({"user": SimpleStoreItem()}) - - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 - assert len(data.keys()) == 1 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_cosmos_storage_read_update_should_return_new_etag(self): @@ -135,27 +226,6 @@ async def test_cosmos_storage_read_with_invalid_key_should_return_empty_dict(sel assert isinstance(data, dict) assert not data.keys() - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_read_no_key_should_throw(self): - try: - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - await storage.read([]) - except Exception as error: - assert error - - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_write_should_add_new_value(self): - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - await storage.write({"user": SimpleStoreItem(counter=1)}) - - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_cosmos_storage_write_should_overwrite_when_new_e_tag_is_an_asterisk( @@ -169,62 +239,6 @@ async def test_cosmos_storage_write_should_overwrite_when_new_e_tag_is_an_asteri data = await storage.read(["user"]) assert data["user"].counter == 10 - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_write_batch_operation(self): - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - await storage.write( - { - "batch1": SimpleStoreItem(counter=1), - "batch2": SimpleStoreItem(counter=1), - "batch3": SimpleStoreItem(counter=1), - } - ) - data = await storage.read(["batch1", "batch2", "batch3"]) - assert len(data.keys()) == 3 - assert data["batch1"] - assert data["batch2"] - assert data["batch3"] - assert data["batch1"].counter == 1 - assert data["batch2"].counter == 1 - assert data["batch3"].counter == 1 - assert data["batch1"].e_tag - assert data["batch2"].e_tag - assert data["batch3"].e_tag - await storage.delete(["batch1", "batch2", "batch3"]) - data = await storage.read(["batch1", "batch2", "batch3"]) - assert not data.keys() - - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_write_crazy_keys_work(self): - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - crazy_key = '!@#$%^&*()_+??><":QASD~`' - await storage.write({crazy_key: SimpleStoreItem(counter=1)}) - data = await storage.read([crazy_key]) - assert len(data.keys()) == 1 - assert data[crazy_key] - assert data[crazy_key].counter == 1 - assert data[crazy_key].e_tag - - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") - @pytest.mark.asyncio - async def test_cosmos_storage_delete_should_delete_according_cached_data(self): - await reset() - storage = CosmosDbStorage(COSMOS_DB_CONFIG) - await storage.write({"test": SimpleStoreItem()}) - try: - await storage.delete(["test"]) - except Exception as error: - raise error - else: - data = await storage.read(["test"]) - - assert isinstance(data, dict) - assert not data.keys() - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") @pytest.mark.asyncio async def test_cosmos_storage_delete_should_delete_multiple_values_when_given_multiple_valid_keys( diff --git a/libraries/botbuilder-core/botbuilder/core/memory_storage.py b/libraries/botbuilder-core/botbuilder/core/memory_storage.py index 73ff77bc4..b85b3d368 100644 --- a/libraries/botbuilder-core/botbuilder/core/memory_storage.py +++ b/libraries/botbuilder-core/botbuilder/core/memory_storage.py @@ -22,6 +22,8 @@ async def delete(self, keys: List[str]): async def read(self, keys: List[str]): data = {} + if not keys: + return data try: for key in keys: if key in self.memory: @@ -32,10 +34,14 @@ async def read(self, keys: List[str]): return data async def write(self, changes: Dict[str, StoreItem]): + if changes is None: + raise Exception("Changes are required when writing") + if not changes: + return try: # iterate over the changes for (key, change) in changes.items(): - new_value = change + new_value = deepcopy(change) old_state_etag = None # Check if the a matching key already exists in self.memory @@ -43,26 +49,35 @@ async def write(self, changes: Dict[str, StoreItem]): if key in self.memory: old_state = self.memory[key] if not isinstance(old_state, StoreItem): - if "eTag" in old_state: - old_state_etag = old_state["eTag"] + old_state_etag = old_state.get("e_tag", None) elif old_state.e_tag: old_state_etag = old_state.e_tag new_state = new_value # Set ETag if applicable - if hasattr(new_value, "e_tag"): - if ( - old_state_etag is not None - and new_value.e_tag != "*" - and new_value.e_tag < old_state_etag - ): - raise KeyError( - "Etag conflict.\nOriginal: %s\r\nCurrent: %s" - % (new_value.e_tag, old_state_etag) - ) + new_value_etag = ( + new_value.e_tag + if hasattr(new_value, "e_tag") + else new_value.get("e_tag", None) + ) + if new_value_etag == "": + raise Exception("blob_storage.write(): etag missing") + if ( + old_state_etag is not None + and new_value_etag is not None + and new_value_etag != "*" + and new_value_etag < old_state_etag + ): + raise KeyError( + "Etag conflict.\nOriginal: %s\r\nCurrent: %s" + % (new_value_etag, old_state_etag) + ) + if hasattr(new_state, "e_tag"): new_state.e_tag = str(self._e_tag) - self._e_tag += 1 + else: + new_state["e_tag"] = str(self._e_tag) + self._e_tag += 1 self.memory[key] = deepcopy(new_state) except Exception as error: diff --git a/libraries/botbuilder-core/tests/test_memory_storage.py b/libraries/botbuilder-core/tests/test_memory_storage.py index 63946ad60..a34e2a94e 100644 --- a/libraries/botbuilder-core/tests/test_memory_storage.py +++ b/libraries/botbuilder-core/tests/test_memory_storage.py @@ -1,9 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import aiounittest +import pytest from botbuilder.core import MemoryStorage, StoreItem +from botbuilder.testing import StorageBaseTests + + +def get_storage(): + return MemoryStorage() class SimpleStoreItem(StoreItem): @@ -13,7 +18,7 @@ def __init__(self, counter=1, e_tag="*"): self.e_tag = e_tag -class TestMemoryStorage(aiounittest.AsyncTestCase): +class TestMemoryStorageConstructor: def test_initializing_memory_storage_without_data_should_still_have_memory(self): storage = MemoryStorage() assert storage.memory is not None @@ -23,6 +28,7 @@ def test_memory_storage__e_tag_should_start_at_0(self): storage = MemoryStorage() assert storage._e_tag == 0 # pylint: disable=protected-access + @pytest.mark.asyncio async def test_memory_storage_initialized_with_memory_should_have_accessible_data( self, ): @@ -32,26 +38,75 @@ async def test_memory_storage_initialized_with_memory_should_have_accessible_dat assert data["test"].counter == 1 assert len(data.keys()) == 1 - async def test_memory_storage_read_should_return_data_with_valid_key(self): - storage = MemoryStorage() - await storage.write({"user": SimpleStoreItem()}) - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 - assert len(data.keys()) == 1 - assert storage._e_tag == 1 # pylint: disable=protected-access - assert int(data["user"].e_tag) == 0 +class TestMemoryStorageBaseTests: + @pytest.mark.asyncio + async def test_return_empty_object_when_reading_unknown_key(self): + test_ran = await StorageBaseTests.return_empty_object_when_reading_unknown_key( + get_storage() + ) - async def test_memory_storage_write_should_add_new_value(self): - storage = MemoryStorage() - aux = {"user": SimpleStoreItem(counter=1)} - await storage.write(aux) + assert test_ran + + @pytest.mark.asyncio + async def test_handle_null_keys_when_reading(self): + test_ran = await StorageBaseTests.handle_null_keys_when_reading(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_handle_null_keys_when_writing(self): + test_ran = await StorageBaseTests.handle_null_keys_when_writing(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_does_not_raise_when_writing_no_items(self): + test_ran = await StorageBaseTests.does_not_raise_when_writing_no_items( + get_storage() + ) + + assert test_ran + + @pytest.mark.asyncio + async def test_create_object(self): + test_ran = await StorageBaseTests.create_object(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_handle_crazy_keys(self): + test_ran = await StorageBaseTests.handle_crazy_keys(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_update_object(self): + test_ran = await StorageBaseTests.update_object(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_delete_object(self): + test_ran = await StorageBaseTests.delete_object(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_perform_batch_operations(self): + test_ran = await StorageBaseTests.perform_batch_operations(get_storage()) + + assert test_ran + + @pytest.mark.asyncio + async def test_proceeds_through_waterfall(self): + test_ran = await StorageBaseTests.proceeds_through_waterfall(get_storage()) + + assert test_ran - data = await storage.read(["user"]) - assert "user" in data - assert data["user"].counter == 1 +class TestMemoryStorage: + @pytest.mark.asyncio async def test_memory_storage_write_should_overwrite_when_new_e_tag_is_an_asterisk_1( self, ): @@ -62,6 +117,7 @@ async def test_memory_storage_write_should_overwrite_when_new_e_tag_is_an_asteri data = await storage.read(["user"]) assert data["user"].counter == 10 + @pytest.mark.asyncio async def test_memory_storage_write_should_overwrite_when_new_e_tag_is_an_asterisk_2( self, ): @@ -72,6 +128,7 @@ async def test_memory_storage_write_should_overwrite_when_new_e_tag_is_an_asteri data = await storage.read(["user"]) assert data["user"].counter == 5 + @pytest.mark.asyncio async def test_memory_storage_read_with_invalid_key_should_return_empty_dict(self): storage = MemoryStorage() data = await storage.read(["test"]) @@ -79,6 +136,7 @@ async def test_memory_storage_read_with_invalid_key_should_return_empty_dict(sel assert isinstance(data, dict) assert not data.keys() + @pytest.mark.asyncio async def test_memory_storage_delete_should_delete_according_cached_data(self): storage = MemoryStorage({"test": "test"}) try: @@ -91,6 +149,7 @@ async def test_memory_storage_delete_should_delete_according_cached_data(self): assert isinstance(data, dict) assert not data.keys() + @pytest.mark.asyncio async def test_memory_storage_delete_should_delete_multiple_values_when_given_multiple_valid_keys( self, ): @@ -102,6 +161,7 @@ async def test_memory_storage_delete_should_delete_multiple_values_when_given_mu data = await storage.read(["test", "test2"]) assert not data.keys() + @pytest.mark.asyncio async def test_memory_storage_delete_should_delete_values_when_given_multiple_valid_keys_and_ignore_other_data( self, ): @@ -117,6 +177,7 @@ async def test_memory_storage_delete_should_delete_values_when_given_multiple_va data = await storage.read(["test", "test2", "test3"]) assert len(data.keys()) == 1 + @pytest.mark.asyncio async def test_memory_storage_delete_invalid_key_should_do_nothing_and_not_affect_cached_data( self, ): @@ -128,6 +189,7 @@ async def test_memory_storage_delete_invalid_key_should_do_nothing_and_not_affec data = await storage.read(["foo"]) assert not data.keys() + @pytest.mark.asyncio async def test_memory_storage_delete_invalid_keys_should_do_nothing_and_not_affect_cached_data( self, ): diff --git a/libraries/botbuilder-testing/botbuilder/testing/__init__.py b/libraries/botbuilder-testing/botbuilder/testing/__init__.py index 681a168e4..af82e1a65 100644 --- a/libraries/botbuilder-testing/botbuilder/testing/__init__.py +++ b/libraries/botbuilder-testing/botbuilder/testing/__init__.py @@ -3,6 +3,7 @@ from .dialog_test_client import DialogTestClient from .dialog_test_logger import DialogTestLogger +from .storage_base_tests import StorageBaseTests -__all__ = ["DialogTestClient", "DialogTestLogger"] +__all__ = ["DialogTestClient", "DialogTestLogger", "StorageBaseTests"] diff --git a/libraries/botbuilder-testing/botbuilder/testing/storage_base_tests.py b/libraries/botbuilder-testing/botbuilder/testing/storage_base_tests.py new file mode 100644 index 000000000..defa5040f --- /dev/null +++ b/libraries/botbuilder-testing/botbuilder/testing/storage_base_tests.py @@ -0,0 +1,337 @@ +""" +Base tests that all storage providers should implement in their own tests. +They handle the storage-based assertions, internally. + +All tests return true if assertions pass to indicate that the code ran to completion, passing internal assertions. +Therefore, all tests using theses static tests should strictly check that the method returns true. + +:Example: + async def test_handle_null_keys_when_reading(self): + await reset() + + test_ran = await StorageBaseTests.handle_null_keys_when_reading(get_storage()) + + assert test_ran +""" +import pytest +from botbuilder.azure import CosmosDbStorage +from botbuilder.core import ( + ConversationState, + TurnContext, + MessageFactory, + MemoryStorage, +) +from botbuilder.core.adapters import TestAdapter +from botbuilder.dialogs import ( + DialogSet, + DialogTurnStatus, + TextPrompt, + PromptValidatorContext, + WaterfallStepContext, + Dialog, + WaterfallDialog, + PromptOptions, +) + + +class StorageBaseTests: + @staticmethod + async def return_empty_object_when_reading_unknown_key(storage) -> bool: + result = await storage.read(["unknown"]) + + assert result is not None + assert len(result) == 0 + + return True + + @staticmethod + async def handle_null_keys_when_reading(storage) -> bool: + if isinstance(storage, (CosmosDbStorage, MemoryStorage)): + result = await storage.read(None) + assert len(result.keys()) == 0 + # Catch-all + else: + with pytest.raises(Exception) as err: + await storage.read(None) + assert err.value.args[0] == "Keys are required when reading" + + return True + + @staticmethod + async def handle_null_keys_when_writing(storage) -> bool: + with pytest.raises(Exception) as err: + await storage.write(None) + assert err.value.args[0] == "Changes are required when writing" + + return True + + @staticmethod + async def does_not_raise_when_writing_no_items(storage) -> bool: + # noinspection PyBroadException + try: + await storage.write([]) + except: + pytest.fail("Should not raise") + + return True + + @staticmethod + async def create_object(storage) -> bool: + store_items = { + "createPoco": {"id": 1}, + "createPocoStoreItem": {"id": 2}, + } + + await storage.write(store_items) + + read_store_items = await storage.read(store_items.keys()) + + assert store_items["createPoco"]["id"] == read_store_items["createPoco"]["id"] + assert ( + store_items["createPocoStoreItem"]["id"] + == read_store_items["createPocoStoreItem"]["id"] + ) + assert read_store_items["createPoco"]["e_tag"] is not None + assert read_store_items["createPocoStoreItem"]["e_tag"] is not None + + return True + + @staticmethod + async def handle_crazy_keys(storage) -> bool: + key = '!@#$%^&*()_+??><":QASD~`' + store_item = {"id": 1} + store_items = {key: store_item} + + await storage.write(store_items) + + read_store_items = await storage.read(store_items.keys()) + + assert read_store_items[key] is not None + assert read_store_items[key]["id"] == 1 + + return True + + @staticmethod + async def update_object(storage) -> bool: + original_store_items = { + "pocoItem": {"id": 1, "count": 1}, + "pocoStoreItem": {"id": 1, "count": 1}, + } + + # 1st write should work + await storage.write(original_store_items) + + loaded_store_items = await storage.read(["pocoItem", "pocoStoreItem"]) + + update_poco_item = loaded_store_items["pocoItem"] + update_poco_item["e_tag"] = None + update_poco_store_item = loaded_store_items["pocoStoreItem"] + assert update_poco_store_item["e_tag"] is not None + + # 2nd write should work + update_poco_item["count"] += 1 + update_poco_store_item["count"] += 1 + + await storage.write(loaded_store_items) + + reloaded_store_items = await storage.read(loaded_store_items.keys()) + + reloaded_update_poco_item = reloaded_store_items["pocoItem"] + reloaded_update_poco_store_item = reloaded_store_items["pocoStoreItem"] + + assert reloaded_update_poco_item["e_tag"] is not None + assert ( + update_poco_store_item["e_tag"] != reloaded_update_poco_store_item["e_tag"] + ) + assert reloaded_update_poco_item["count"] == 2 + assert reloaded_update_poco_store_item["count"] == 2 + + # Write with old e_tag should succeed for non-storeItem + update_poco_item["count"] = 123 + await storage.write({"pocoItem": update_poco_item}) + + # Write with old eTag should FAIL for storeItem + update_poco_store_item["count"] = 123 + + with pytest.raises(Exception) as err: + await storage.write({"pocoStoreItem": update_poco_store_item}) + assert err.value is not None + + reloaded_store_items2 = await storage.read(["pocoItem", "pocoStoreItem"]) + + reloaded_poco_item2 = reloaded_store_items2["pocoItem"] + reloaded_poco_item2["e_tag"] = None + reloaded_poco_store_item2 = reloaded_store_items2["pocoStoreItem"] + + assert reloaded_poco_item2["count"] == 123 + assert reloaded_poco_store_item2["count"] == 2 + + # write with wildcard etag should work + reloaded_poco_item2["count"] = 100 + reloaded_poco_store_item2["count"] = 100 + reloaded_poco_store_item2["e_tag"] = "*" + + wildcard_etag_dict = { + "pocoItem": reloaded_poco_item2, + "pocoStoreItem": reloaded_poco_store_item2, + } + + await storage.write(wildcard_etag_dict) + + reloaded_store_items3 = await storage.read(["pocoItem", "pocoStoreItem"]) + + assert reloaded_store_items3["pocoItem"]["count"] == 100 + assert reloaded_store_items3["pocoStoreItem"]["count"] == 100 + + # Write with empty etag should not work + reloaded_store_items4 = await storage.read(["pocoStoreItem"]) + reloaded_store_item4 = reloaded_store_items4["pocoStoreItem"] + + assert reloaded_store_item4 is not None + + reloaded_store_item4["e_tag"] = "" + dict2 = {"pocoStoreItem": reloaded_store_item4} + + with pytest.raises(Exception) as err: + await storage.write(dict2) + assert err.value is not None + + final_store_items = await storage.read(["pocoItem", "pocoStoreItem"]) + assert final_store_items["pocoItem"]["count"] == 100 + assert final_store_items["pocoStoreItem"]["count"] == 100 + + return True + + @staticmethod + async def delete_object(storage) -> bool: + store_items = {"delete1": {"id": 1, "count": 1}} + + await storage.write(store_items) + + read_store_items = await storage.read(["delete1"]) + + assert read_store_items["delete1"]["e_tag"] + assert read_store_items["delete1"]["count"] == 1 + + await storage.delete(["delete1"]) + + reloaded_store_items = await storage.read(["delete1"]) + + assert reloaded_store_items.get("delete1", None) is None + + return True + + @staticmethod + async def delete_unknown_object(storage) -> bool: + # noinspection PyBroadException + try: + await storage.delete(["unknown_key"]) + except: + pytest.fail("Should not raise") + + return True + + @staticmethod + async def perform_batch_operations(storage) -> bool: + await storage.write( + {"batch1": {"count": 10}, "batch2": {"count": 20}, "batch3": {"count": 30},} + ) + + result = await storage.read(["batch1", "batch2", "batch3"]) + + assert result.get("batch1", None) is not None + assert result.get("batch2", None) is not None + assert result.get("batch3", None) is not None + assert result["batch1"]["count"] == 10 + assert result["batch2"]["count"] == 20 + assert result["batch3"]["count"] == 30 + assert result["batch1"].get("e_tag", None) is not None + assert result["batch2"].get("e_tag", None) is not None + assert result["batch3"].get("e_tag", None) is not None + + await storage.delete(["batch1", "batch2", "batch3"]) + + result = await storage.read(["batch1", "batch2", "batch3"]) + + assert result.get("batch1", None) is None + assert result.get("batch2", None) is None + assert result.get("batch3", None) is None + + return True + + @staticmethod + async def proceeds_through_waterfall(storage) -> bool: + convo_state = ConversationState(storage) + + dialog_state = convo_state.create_property("dialogState") + dialogs = DialogSet(dialog_state) + + async def exec_test(turn_context: TurnContext) -> None: + dialog_context = await dialogs.create_context(turn_context) + + await dialog_context.continue_dialog() + if not turn_context.responded: + await dialog_context.begin_dialog(WaterfallDialog.__name__) + await convo_state.save_changes(turn_context) + + adapter = TestAdapter(exec_test) + + async def prompt_validator(prompt_context: PromptValidatorContext): + result = prompt_context.recognized.value + if len(result) > 3: + succeeded_message = MessageFactory.text( + f"You got it at the {prompt_context.options.number_of_attempts}rd try!" + ) + await prompt_context.context.send_activity(succeeded_message) + return True + + reply = MessageFactory.text( + f"Please send a name that is longer than 3 characters. {prompt_context.options.number_of_attempts}" + ) + await prompt_context.context.send_activity(reply) + return False + + async def step_1(step_context: WaterfallStepContext) -> DialogTurnStatus: + assert isinstance(step_context.active_dialog.state["stepIndex"], int) + await step_context.context.send_activity("step1") + return Dialog.end_of_turn + + async def step_2(step_context: WaterfallStepContext) -> None: + assert isinstance(step_context.active_dialog.state["stepIndex"], int) + await step_context.prompt( + TextPrompt.__name__, + PromptOptions(prompt=MessageFactory.text("Please type your name")), + ) + + async def step_3(step_context: WaterfallStepContext) -> DialogTurnStatus: + assert isinstance(step_context.active_dialog.state["stepIndex"], int) + await step_context.context.send_activity("step3") + return Dialog.end_of_turn + + steps = [step_1, step_2, step_3] + + dialogs.add(WaterfallDialog(WaterfallDialog.__name__, steps)) + + dialogs.add(TextPrompt(TextPrompt.__name__, prompt_validator)) + + step1 = await adapter.send("hello") + step2 = await step1.assert_reply("step1") + step3 = await step2.send("hello") + step4 = await step3.assert_reply("Please type your name") # None + step5 = await step4.send("hi") + step6 = await step5.assert_reply( + "Please send a name that is longer than 3 characters. 0" + ) + step7 = await step6.send("hi") + step8 = await step7.assert_reply( + "Please send a name that is longer than 3 characters. 1" + ) + step9 = await step8.send("hi") + step10 = await step9.assert_reply( + "Please send a name that is longer than 3 characters. 2" + ) + step11 = await step10.send("Kyle") + step12 = await step11.assert_reply("You got it at the 3rd try!") + await step12.assert_reply("step3") + + return True