feat(aws_secret_manager.py): allows user to keep a hash of the proxy master key in their env

This commit is contained in:
Krrish Dholakia 2024-06-06 15:32:51 -07:00
parent 707a41fd6c
commit a2da2a8f16
5 changed files with 62 additions and 8 deletions

View file

@ -7351,10 +7351,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
if custom_llm_provider == "databricks":
return litellm.DatabricksConfig().get_required_params()
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_required_params()
else:
return []
@ -10052,6 +10052,8 @@ def get_secret(
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
@ -10139,13 +10141,13 @@ def get_secret(
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
@ -10173,6 +10175,25 @@ def get_secret(
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
@ -10193,10 +10214,14 @@ def get_secret(
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)