fix bug where oidc audience that contains "/" won't be extract correctly

This commit is contained in:
Niko Izsak 2025-04-16 15:04:35 +02:00
parent 15ac0bd440
commit cdc14fa7fb

View file

@ -105,6 +105,7 @@ def get_secret( # noqa: PLR0915
if secret_name.startswith("oidc/"): if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "") secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1) oidc_provider, oidc_aud = secret_name_split.split("/", 1)
oidc_aud = "/".join(secret_name_split.split("/")[1:])
# TODO: Add caching for HTTP requests # TODO: Add caching for HTTP requests
if oidc_provider == "google": if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name) oidc_token = oidc_cache.get_cache(key=secret_name)
@ -140,10 +141,7 @@ def get_secret( # noqa: PLR0915
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if ( if actions_id_token_request_url is None or actions_id_token_request_token is None:
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError( raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment" "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
) )
@ -171,7 +169,9 @@ def get_secret( # noqa: PLR0915
# https://azure.github.io/azure-workload-identity/docs/quick-start.html # https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None: if azure_federated_token_file is None:
verbose_logger.warning("AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider") verbose_logger.warning(
"AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
)
azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud) azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
oidc_token = azure_token_provider() oidc_token = azure_token_provider()
if oidc_token is None: if oidc_token is None:
@ -203,10 +203,7 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider") raise ValueError("Unsupported OIDC provider")
try: try:
if ( if _should_read_secret_from_secret_manager() and litellm.secret_manager_client is not None:
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
try: try:
client = litellm.secret_manager_client client = litellm.secret_manager_client
key_manager = "local" key_manager = "local"
@ -232,9 +229,7 @@ def get_secret( # noqa: PLR0915
): ):
encrypted_secret: Any = os.getenv(secret_name) encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None: if encrypted_secret is None:
raise ValueError( raise ValueError("Google KMS requires the encrypted secret to be in the environment!")
"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret) b64_flag = _is_base64(encrypted_secret)
if b64_flag is True: # if passed in as encoded b64 string if b64_flag is True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret) encrypted_secret = base64.b64decode(encrypted_secret)
@ -249,20 +244,14 @@ def get_secret( # noqa: PLR0915
"ciphertext": ciphertext, "ciphertext": ciphertext,
} }
) )
secret = response.plaintext.decode( secret = response.plaintext.decode("utf-8") # assumes the original value was encoded with utf-8
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value: elif key_manager == KeyManagementSystem.AWS_KMS.value:
""" """
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
""" """
encrypted_value = os.getenv(secret_name, None) encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None: if encrypted_value is None:
raise Exception( raise Exception("AWS KMS - Encrypted Value of Key={} is None".format(secret_name))
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext # Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value) ciphertext_blob = base64.b64decode(encrypted_value)
@ -289,14 +278,10 @@ def get_secret( # noqa: PLR0915
print_verbose(f"get_secret_value_response: {secret}") print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try: try:
secret = client.get_secret_from_google_secret_manager( secret = client.get_secret_from_google_secret_manager(secret_name)
secret_name
)
print_verbose(f"secret from google secret manager: {secret}") print_verbose(f"secret from google secret manager: {secret}")
if secret is None: if secret is None:
raise ValueError( raise ValueError(f"No secret found in Google Secret Manager for {secret_name}")
f"No secret found in Google Secret Manager for {secret_name}"
)
except Exception as e: except Exception as e:
print_verbose(f"An error occurred - {str(e)}") print_verbose(f"An error occurred - {str(e)}")
raise e raise e
@ -304,9 +289,7 @@ def get_secret( # noqa: PLR0915
try: try:
secret = client.sync_read_secret(secret_name=secret_name) secret = client.sync_read_secret(secret_name=secret_name)
if secret is None: if secret is None:
raise ValueError( raise ValueError(f"No secret found in Hashicorp Secret Manager for {secret_name}")
f"No secret found in Hashicorp Secret Manager for {secret_name}"
)
except Exception as e: except Exception as e:
print_verbose(f"An error occurred - {str(e)}") print_verbose(f"An error occurred - {str(e)}")
raise e raise e
@ -331,9 +314,7 @@ def get_secret( # noqa: PLR0915
else: else:
secret = os.environ.get(secret_name) secret = os.environ.get(secret_name)
secret_value_as_bool = str_to_bool(secret) if secret is not None else None secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance( if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
secret_value_as_bool, bool
):
return secret_value_as_bool return secret_value_as_bool
else: else:
return secret return secret