Merge pull request #4576 from BerriAI/litellm_encrypt_decrypt_using_salt

[Refactor] Use helper function to encrypt/decrypt model credentials
This commit is contained in:
Ishaan Jaff 2024-07-06 15:11:09 -07:00 committed by GitHub
commit d61cc598b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 300 additions and 240 deletions

View file

@ -140,7 +140,15 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_utils.admin_ui_utils import (
html_form,
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.openai_endpoint_utils import (
@ -186,13 +194,9 @@ from litellm.proxy.utils import (
_get_projected_spend_over_limit,
_is_projected_spend_over_limit,
_is_valid_team_configs,
decrypt_value,
encrypt_value,
get_error_message_str,
get_instance_fn,
hash_token,
html_form,
missing_keys_html_form,
reset_budget,
send_email,
update_spend,
@ -1243,6 +1247,7 @@ class ProxyConfig:
## DB
if prisma_client is not None and (
general_settings.get("store_model_in_db", False) == True
or store_model_in_db is True
):
_tasks = []
keys = [
@ -1885,16 +1890,8 @@ class ProxyConfig:
# decrypt values
for k, v in _litellm_params.items():
if isinstance(v, str):
# decode base64
try:
decoded_b64 = base64.b64decode(v)
except Exception as e:
verbose_proxy_logger.error(
"Error decoding value - {}".format(v)
)
continue
# decrypt value
_value = decrypt_value(value=decoded_b64, master_key=master_key)
_value = decrypt_value_helper(value=v)
# sanity check if string > size 0
if len(_value) > 0:
_litellm_params[k] = _value
@ -1938,13 +1935,8 @@ class ProxyConfig:
if isinstance(_litellm_params, dict):
# decrypt values
for k, v in _litellm_params.items():
if isinstance(v, str):
# decode base64
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key # type: ignore
)
decrypted_value = decrypt_value_helper(value=v)
_litellm_params[k] = decrypted_value
_litellm_params = LiteLLM_Params(**_litellm_params)
else:
verbose_proxy_logger.error(
@ -2005,10 +1997,8 @@ class ProxyConfig:
environment_variables = config_data.get("environment_variables", {})
for k, v in environment_variables.items():
try:
if v is not None:
decoded_b64 = base64.b64decode(v)
value = decrypt_value(value=decoded_b64, master_key=master_key) # type: ignore
os.environ[k] = value
decrypted_value = decrypt_value_helper(value=v)
os.environ[k] = decrypted_value
except Exception as e:
verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e)
@ -5941,11 +5931,8 @@ async def add_new_model(
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
_orignal_litellm_model_name = model_params.litellm_params.model
for k, v in _litellm_params_dict.items():
if isinstance(v, str):
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
model_params.litellm_params[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
encrypted_value = encrypt_value_helper(value=v)
model_params.litellm_params[k] = encrypted_value
_data: dict = {
"model_id": model_params.model_info.id,
"model_name": model_params.model_name,
@ -6076,11 +6063,8 @@ async def update_model(
### ENCRYPT PARAMS ###
for k, v in _new_litellm_params_dict.items():
if isinstance(v, str):
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
model_params.litellm_params[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
encrypted_value = encrypt_value_helper(value=v)
model_params.litellm_params[k] = encrypted_value
### MERGE WITH EXISTING DATA ###
merged_dictionary = {}
@ -7198,10 +7182,9 @@ async def google_login(request: Request):
)
####### Detect DB + MASTER KEY in .env #######
if prisma_client is None or master_key is None:
from fastapi.responses import HTMLResponse
return HTMLResponse(content=missing_keys_html_form, status_code=200)
missing_env_vars = show_missing_vars_in_env()
if missing_env_vars is not None:
return missing_env_vars
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
@ -8404,11 +8387,8 @@ async def update_config(config_info: ConfigYAML):
# encrypt updated_environment_variables #
for k, v in _updated_environment_variables.items():
if isinstance(v, str):
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
_updated_environment_variables[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
encrypted_value = encrypt_value_helper(value=v)
_updated_environment_variables[k] = encrypted_value
_existing_env_variables = config["environment_variables"]
@ -8825,11 +8805,8 @@ async def get_config():
env_vars_dict[_var] = None
else:
# decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
env_vars_dict[_var] = _decrypted_value
decrypted_value = decrypt_value_helper(value=env_variable)
env_vars_dict[_var] = decrypted_value
_data_to_return.append({"name": _callback, "variables": env_vars_dict})
elif _callback == "langfuse":
@ -8845,11 +8822,8 @@ async def get_config():
_langfuse_env_vars[_var] = None
else:
# decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
_langfuse_env_vars[_var] = _decrypted_value
decrypted_value = decrypt_value_helper(value=env_variable)
_langfuse_env_vars[_var] = decrypted_value
_data_to_return.append(
{"name": _callback, "variables": _langfuse_env_vars}
@ -8870,10 +8844,7 @@ async def get_config():
_slack_env_vars[_var] = _value
else:
# decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
_decrypted_value = decrypt_value_helper(value=env_variable)
_slack_env_vars[_var] = _decrypted_value
_alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types
@ -8909,10 +8880,7 @@ async def get_config():
_email_env_vars[_var] = None
else:
# decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
_decrypted_value = decrypt_value_helper(value=env_variable)
_email_env_vars[_var] = _decrypted_value
alerting_data.append(