Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features Added

- Added on_refresh_success callback to load method. This callback is called when the refresh method successfully refreshes the configuration.
- Added minimum up time. This is the minimum amount of time the provider will try to be up before throwing an error. This is to prevent quick restart loops.

### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import random
import time
import datetime
from threading import Lock
import logging
from typing import (
Expand Down Expand Up @@ -51,6 +52,8 @@

logger = logging.getLogger(__name__)

min_uptime = 5


@overload
def load(
Expand Down Expand Up @@ -151,6 +154,7 @@ def load(*args, **kwargs) -> "AzureAppConfigurationProvider":
credential: Optional["TokenCredential"] = kwargs.pop("credential", None)
connection_string: Optional[str] = kwargs.pop("connection_string", None)
key_vault_options: Optional[AzureAppConfigurationKeyVaultOptions] = kwargs.pop("key_vault_options", None)
start_time = datetime.datetime.now()

# Update endpoint and credential if specified positionally.
if len(args) > 2:
Expand Down Expand Up @@ -186,7 +190,11 @@ def load(*args, **kwargs) -> "AzureAppConfigurationProvider":
provider = _buildprovider(
connection_string, endpoint, credential, uses_key_vault="UsesKeyVault" in headers, **kwargs
)
provider._load_all(headers=headers)
try:
provider._load_all(headers=headers)
except Exception as e:
_prekill(start_time)
raise e

# Refresh-All sentinels are not updated on load_all, as they are not necessarily included in the provider.
for (key, label), etag in provider._refresh_on.items():
Expand All @@ -203,10 +211,23 @@ def load(*args, **kwargs) -> "AzureAppConfigurationProvider":
label,
)
else:
_prekill(start_time)
raise e
except Exception as e:
_prekill(start_time)
raise e
return provider


def _prekill(start_time: datetime.datetime) -> None:
# We want to make sure we are up a minimum amount of time before we kill the process. Otherwise, we could get stuck
# in a quick restart loop.
min_time = datetime.timedelta(seconds=min_uptime)
current_time = datetime.datetime.now()
if current_time - start_time < min_time:
time.sleep(min_time - (current_time - start_time))


def _get_headers(request_type, **kwargs) -> str:
headers = kwargs.pop("headers", {})
if os.environ.get(REQUEST_TRACING_DISABLED_ENVIRONMENT_VARIABLE, default="").lower() != "true":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import json
import datetime
from asyncio.locks import Lock
import logging
from typing import (
Expand Down Expand Up @@ -42,6 +43,7 @@
_get_headers,
_RefreshTimer,
_build_sentinel,
_prekill,
)
from .._user_agent import USER_AGENT

Expand Down Expand Up @@ -153,6 +155,7 @@ async def load(*args, **kwargs) -> "AzureAppConfigurationProvider":
credential: Optional["AsyncTokenCredential"] = kwargs.pop("credential", None)
connection_string: Optional[str] = kwargs.pop("connection_string", None)
key_vault_options: Optional[AzureAppConfigurationKeyVaultOptions] = kwargs.pop("key_vault_options", None)
start_time = datetime.datetime.now()

# Update endpoint and credential if specified positionally.
if len(args) > 2:
Expand Down Expand Up @@ -186,7 +189,12 @@ async def load(*args, **kwargs) -> "AzureAppConfigurationProvider":

headers = _get_headers("Startup", **kwargs)
provider = _buildprovider(connection_string, endpoint, credential, **kwargs)
await provider._load_all(headers=headers)

try:
await provider._load_all(headers=headers)
except Exception as e:
_prekill(start_time)
raise e

# Refresh-All sentinels are not updated on load_all, as they are not necessarily included in the provider.
for (key, label), etag in provider._refresh_on.items():
Expand All @@ -204,7 +212,11 @@ async def load(*args, **kwargs) -> "AzureAppConfigurationProvider":
)
provider._refresh_on[(key, label)] = None
else:
_prekill(start_time)
raise e
except Exception as e:
_prekill(start_time)
raise e
return provider


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from devtools_testutils import recorded_by_proxy
from preparers import app_config_decorator
from testcase import AppConfigTestCase
import datetime
from unittest.mock import patch

from azure.appconfiguration.provider._azureappconfigurationprovider import _prekill


class TestAppConfigurationProvider(AppConfigTestCase):
Expand Down Expand Up @@ -104,6 +108,18 @@ def test_provider_secret_resolver_options(self, appconfiguration_connection_stri
)
assert client["secret"] == "Reslover Value"

# method: _prekill
@patch("time.sleep")
def test_prekill(self, mock_sleep, **kwargs):
start_time = datetime.datetime.now()
_prekill(start_time)
assert mock_sleep.call_count == 1

mock_sleep.reset_mock()
start_time = datetime.datetime.now() - datetime.timedelta(seconds=10)
_prekill(start_time)
mock_sleep.assert_not_called()


def secret_resolver(secret_id):
return "Reslover Value"