diff --git a/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt.py b/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt.py index 588b22c38..c9d8bb5a9 100644 --- a/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt.py +++ b/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt.py @@ -203,13 +203,17 @@ async def continue_dialog(self, dialog_context: DialogContext) -> DialogTurnResu The prompt generally continues to receive the user's replies until it accepts the user's reply as valid input for the prompt. """ - # Recognize token - recognized = await self._recognize_token(dialog_context) - # Check for timeout state = dialog_context.active_dialog.state is_message = dialog_context.context.activity.type == ActivityTypes.message - has_timed_out = is_message and ( + is_timeout_activity_type = ( + is_message + or OAuthPrompt._is_token_response_event(dialog_context.context) + or OAuthPrompt._is_teams_verification_invoke(dialog_context.context) + or OAuthPrompt._is_token_exchange_request_invoke(dialog_context.context) + ) + + has_timed_out = is_timeout_activity_type and ( datetime.now() > state[OAuthPrompt.PERSISTED_EXPIRES] ) @@ -221,6 +225,9 @@ async def continue_dialog(self, dialog_context: DialogContext) -> DialogTurnResu else: state["state"]["attemptCount"] += 1 + # Recognize token + recognized = await self._recognize_token(dialog_context) + # Validate the return value is_valid = False if self._validator is not None: @@ -238,6 +245,9 @@ async def continue_dialog(self, dialog_context: DialogContext) -> DialogTurnResu # Return recognized value or re-prompt if is_valid: return await dialog_context.end_dialog(recognized.value) + if is_message and self._settings.end_on_invalid_message: + # If EndOnInvalidMessage is set, complete the prompt with no result. + return await dialog_context.end_dialog(None) # Send retry prompt if ( diff --git a/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt_settings.py b/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt_settings.py index 1d8f04eca..c071c590e 100644 --- a/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt_settings.py +++ b/libraries/botbuilder-dialogs/botbuilder/dialogs/prompts/oauth_prompt_settings.py @@ -11,6 +11,7 @@ def __init__( text: str = None, timeout: int = None, oauth_app_credentials: AppCredentials = None, + end_on_invalid_message: bool = False, ): """ Settings used to configure an `OAuthPrompt` instance. @@ -22,9 +23,15 @@ def __init__( `OAuthPrompt` defaults value to `900,000` ms (15 minutes). oauth_app_credentials (AppCredentials): (Optional) AppCredentials to use for OAuth. If None, the Bots credentials are used. + end_on_invalid_message (bool): (Optional) value indicating whether the OAuthPrompt should end upon + receiving an invalid message. Generally the OAuthPrompt will ignore incoming messages from the + user during the auth flow, if they are not related to the auth flow. This flag enables ending the + OAuthPrompt rather than ignoring the user's message. Typically, this flag will be set to 'true', + but is 'false' by default for backwards compatibility. """ self.connection_name = connection_name self.title = title self.text = text self.timeout = timeout self.oath_app_credentials = oauth_app_credentials + self.end_on_invalid_message = end_on_invalid_message diff --git a/libraries/botbuilder-dialogs/tests/test_oauth_prompt.py b/libraries/botbuilder-dialogs/tests/test_oauth_prompt.py index a5802103a..a6b22553b 100644 --- a/libraries/botbuilder-dialogs/tests/test_oauth_prompt.py +++ b/libraries/botbuilder-dialogs/tests/test_oauth_prompt.py @@ -9,6 +9,7 @@ ChannelAccount, ConversationAccount, InputHints, + SignInConstants, TokenResponse, ) @@ -260,3 +261,156 @@ async def callback_handler(turn_context: TurnContext): await adapter.send("Hello") self.assertTrue(called) + + async def test_should_end_oauth_prompt_on_invalid_message_when_end_on_invalid_message( + self, + ): + connection_name = "myConnection" + token = "abc123" + magic_code = "888999" + + async def exec_test(turn_context: TurnContext): + dialog_context = await dialogs.create_context(turn_context) + + results = await dialog_context.continue_dialog() + + if results.status == DialogTurnStatus.Empty: + await dialog_context.prompt("prompt", PromptOptions()) + elif results.status == DialogTurnStatus.Complete: + if results.result and results.result.token: + await turn_context.send_activity("Failed") + + else: + await turn_context.send_activity("Ended") + + await convo_state.save_changes(turn_context) + + # Initialize TestAdapter. + adapter = TestAdapter(exec_test) + + # Create ConversationState with MemoryStorage and register the state as middleware. + convo_state = ConversationState(MemoryStorage()) + + # Create a DialogState property, DialogSet and AttachmentPrompt. + dialog_state = convo_state.create_property("dialog_state") + dialogs = DialogSet(dialog_state) + dialogs.add( + OAuthPrompt( + "prompt", + OAuthPromptSettings(connection_name, "Login", None, 300000, None, True), + ) + ) + + def inspector( + activity: Activity, description: str = None + ): # pylint: disable=unused-argument + assert len(activity.attachments) == 1 + assert ( + activity.attachments[0].content_type + == CardFactory.content_types.oauth_card + ) + + # send a mock EventActivity back to the bot with the token + adapter.add_user_token( + connection_name, + activity.channel_id, + activity.recipient.id, + token, + magic_code, + ) + + step1 = await adapter.send("Hello") + step2 = await step1.assert_reply(inspector) + step3 = await step2.send("test invalid message") + await step3.assert_reply("Ended") + + async def test_should_timeout_oauth_prompt_with_message_activity(self,): + activity = Activity(type=ActivityTypes.message, text="any") + await self.run_timeout_test(activity) + + async def test_should_timeout_oauth_prompt_with_token_response_event_activity( + self, + ): + activity = Activity( + type=ActivityTypes.event, name=SignInConstants.token_response_event_name + ) + await self.run_timeout_test(activity) + + async def test_should_timeout_oauth_prompt_with_verify_state_operation_activity( + self, + ): + activity = Activity( + type=ActivityTypes.invoke, name=SignInConstants.verify_state_operation_name + ) + await self.run_timeout_test(activity) + + async def test_should_not_timeout_oauth_prompt_with_custom_event_activity(self,): + activity = Activity(type=ActivityTypes.event, name="custom event name") + await self.run_timeout_test(activity, False, "Ended", "Failed") + + async def run_timeout_test( + self, + activity: Activity, + should_succeed: bool = True, + token_response: str = "Failed", + no_token_resonse="Ended", + ): + connection_name = "myConnection" + token = "abc123" + magic_code = "888999" + + async def exec_test(turn_context: TurnContext): + dialog_context = await dialogs.create_context(turn_context) + + results = await dialog_context.continue_dialog() + + if results.status == DialogTurnStatus.Empty: + await dialog_context.prompt("prompt", PromptOptions()) + elif results.status == DialogTurnStatus.Complete or ( + results.status == DialogTurnStatus.Waiting and not should_succeed + ): + if results.result and results.result.token: + await turn_context.send_activity(token_response) + + else: + await turn_context.send_activity(no_token_resonse) + + await convo_state.save_changes(turn_context) + + # Initialize TestAdapter. + adapter = TestAdapter(exec_test) + + # Create ConversationState with MemoryStorage and register the state as middleware. + convo_state = ConversationState(MemoryStorage()) + + # Create a DialogState property, DialogSet and AttachmentPrompt. + dialog_state = convo_state.create_property("dialog_state") + dialogs = DialogSet(dialog_state) + dialogs.add( + OAuthPrompt( + "prompt", OAuthPromptSettings(connection_name, "Login", None, 1), + ) + ) + + def inspector( + activity: Activity, description: str = None + ): # pylint: disable=unused-argument + assert len(activity.attachments) == 1 + assert ( + activity.attachments[0].content_type + == CardFactory.content_types.oauth_card + ) + + # send a mock EventActivity back to the bot with the token + adapter.add_user_token( + connection_name, + activity.channel_id, + activity.recipient.id, + token, + magic_code, + ) + + step1 = await adapter.send("Hello") + step2 = await step1.assert_reply(inspector) + step3 = await step2.send(activity) + await step3.assert_reply(no_token_resonse)