Merge pull request #4576 from BerriAI/litellm_encrypt_decrypt_using_salt

[Refactor] Use helper function to encrypt/decrypt model credentials
This commit is contained in:
Ishaan Jaff 2024-07-06 15:11:09 -07:00 committed by GitHub
commit d61cc598b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 300 additions and 240 deletions

View file

@ -0,0 +1,167 @@
import os
def show_missing_vars_in_env():
from fastapi.responses import HTMLResponse
from litellm.proxy.proxy_server import master_key, prisma_client
if prisma_client is None and master_key is None:
return HTMLResponse(
content=missing_keys_form(
missing_key_names="DATABASE_URL, LITELLM_MASTER_KEY"
),
status_code=200,
)
if prisma_client is None:
return HTMLResponse(
content=missing_keys_form(missing_key_names="DATABASE_URL"), status_code=200
)
if master_key is None:
return HTMLResponse(
content=missing_keys_form(missing_key_names="LITELLM_MASTER_KEY"),
status_code=200,
)
return None
# LiteLLM Admin UI - Non SSO Login
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
url_to_redirect_to += "/login"
html_form = f"""
<!DOCTYPE html>
<html>
<head>
<title>LiteLLM Login</title>
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}}
form {{
background-color: #fff;
padding: 20px;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
label {{
display: block;
margin-bottom: 8px;
}}
input {{
width: 100%;
padding: 8px;
margin-bottom: 16px;
box-sizing: border-box;
border: 1px solid #ccc;
border-radius: 4px;
}}
input[type="submit"] {{
background-color: #4caf50;
color: #fff;
cursor: pointer;
}}
input[type="submit"]:hover {{
background-color: #45a049;
}}
</style>
</head>
<body>
<form action="{url_to_redirect_to}" method="post">
<h2>LiteLLM Login</h2>
<p>By default Username is "admin" and Password is your set LiteLLM Proxy `MASTER_KEY`</p>
<p>If you need to set UI credentials / SSO docs here: <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">https://docs.litellm.ai/docs/proxy/ui</a></p>
<br>
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
<label for="password">Password:</label>
<input type="password" id="password" name="password" required>
<input type="submit" value="Submit">
</form>
"""
def missing_keys_form(missing_key_names: str):
missing_keys_html_form = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following variables to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
<span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
<h1>Missing Environment Variables</h1>
<p>{missing_keys}</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
return missing_keys_html_form.format(missing_keys=missing_key_names)

View file

@ -0,0 +1,89 @@
import base64
import os
from litellm._logging import verbose_proxy_logger
LITELLM_SALT_KEY = os.getenv("LITELLM_SALT_KEY", None)
verbose_proxy_logger.debug(
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
)
def encrypt_value_helper(value: str):
from litellm.proxy.proxy_server import master_key
signing_key = LITELLM_SALT_KEY
if LITELLM_SALT_KEY is None:
signing_key = master_key
try:
if isinstance(value, str):
encrypted_value = encrypt_value(value=value, signing_key=signing_key) # type: ignore
encrypted_value = base64.b64encode(encrypted_value).decode("utf-8")
return encrypted_value
raise ValueError(
f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
)
except Exception as e:
raise e
def decrypt_value_helper(value: str):
from litellm.proxy.proxy_server import master_key
signing_key = LITELLM_SALT_KEY
if LITELLM_SALT_KEY is None:
signing_key = master_key
try:
if isinstance(value, str):
decoded_b64 = base64.b64decode(value)
value = decrypt_value(value=decoded_b64, signing_key=signing_key) # type: ignore
return value
except Exception as e:
verbose_proxy_logger.error(f"Error decrypting value: {value}\nError: {str(e)}")
# [Non-Blocking Exception. - this should not block decrypting other values]
pass
def encrypt_value(value: str, signing_key: str):
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(signing_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# encode message #
value_bytes = value.encode("utf-8")
encrypted = box.encrypt(value_bytes)
return encrypted
def decrypt_value(value: bytes, signing_key: str) -> str:
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(signing_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# Convert the bytes object to a string
plaintext = box.decrypt(value)
plaintext = plaintext.decode("utf-8") # type: ignore
return plaintext # type: ignore

View file

@ -140,7 +140,15 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
## Import All Misc routes here ## ## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_utils.admin_ui_utils import (
html_form,
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import (
@ -186,13 +194,9 @@ from litellm.proxy.utils import (
_get_projected_spend_over_limit, _get_projected_spend_over_limit,
_is_projected_spend_over_limit, _is_projected_spend_over_limit,
_is_valid_team_configs, _is_valid_team_configs,
decrypt_value,
encrypt_value,
get_error_message_str, get_error_message_str,
get_instance_fn, get_instance_fn,
hash_token, hash_token,
html_form,
missing_keys_html_form,
reset_budget, reset_budget,
send_email, send_email,
update_spend, update_spend,
@ -1243,6 +1247,7 @@ class ProxyConfig:
## DB ## DB
if prisma_client is not None and ( if prisma_client is not None and (
general_settings.get("store_model_in_db", False) == True general_settings.get("store_model_in_db", False) == True
or store_model_in_db is True
): ):
_tasks = [] _tasks = []
keys = [ keys = [
@ -1885,16 +1890,8 @@ class ProxyConfig:
# decrypt values # decrypt values
for k, v in _litellm_params.items(): for k, v in _litellm_params.items():
if isinstance(v, str): if isinstance(v, str):
# decode base64
try:
decoded_b64 = base64.b64decode(v)
except Exception as e:
verbose_proxy_logger.error(
"Error decoding value - {}".format(v)
)
continue
# decrypt value # decrypt value
_value = decrypt_value(value=decoded_b64, master_key=master_key) _value = decrypt_value_helper(value=v)
# sanity check if string > size 0 # sanity check if string > size 0
if len(_value) > 0: if len(_value) > 0:
_litellm_params[k] = _value _litellm_params[k] = _value
@ -1938,13 +1935,8 @@ class ProxyConfig:
if isinstance(_litellm_params, dict): if isinstance(_litellm_params, dict):
# decrypt values # decrypt values
for k, v in _litellm_params.items(): for k, v in _litellm_params.items():
if isinstance(v, str): decrypted_value = decrypt_value_helper(value=v)
# decode base64 _litellm_params[k] = decrypted_value
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key # type: ignore
)
_litellm_params = LiteLLM_Params(**_litellm_params) _litellm_params = LiteLLM_Params(**_litellm_params)
else: else:
verbose_proxy_logger.error( verbose_proxy_logger.error(
@ -2005,10 +1997,8 @@ class ProxyConfig:
environment_variables = config_data.get("environment_variables", {}) environment_variables = config_data.get("environment_variables", {})
for k, v in environment_variables.items(): for k, v in environment_variables.items():
try: try:
if v is not None: decrypted_value = decrypt_value_helper(value=v)
decoded_b64 = base64.b64decode(v) os.environ[k] = decrypted_value
value = decrypt_value(value=decoded_b64, master_key=master_key) # type: ignore
os.environ[k] = value
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e) "Error setting env variable: %s - %s", k, str(e)
@ -5941,11 +5931,8 @@ async def add_new_model(
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True) _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
_orignal_litellm_model_name = model_params.litellm_params.model _orignal_litellm_model_name = model_params.litellm_params.model
for k, v in _litellm_params_dict.items(): for k, v in _litellm_params_dict.items():
if isinstance(v, str): encrypted_value = encrypt_value_helper(value=v)
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore model_params.litellm_params[k] = encrypted_value
model_params.litellm_params[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
_data: dict = { _data: dict = {
"model_id": model_params.model_info.id, "model_id": model_params.model_info.id,
"model_name": model_params.model_name, "model_name": model_params.model_name,
@ -6076,11 +6063,8 @@ async def update_model(
### ENCRYPT PARAMS ### ### ENCRYPT PARAMS ###
for k, v in _new_litellm_params_dict.items(): for k, v in _new_litellm_params_dict.items():
if isinstance(v, str): encrypted_value = encrypt_value_helper(value=v)
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore model_params.litellm_params[k] = encrypted_value
model_params.litellm_params[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
### MERGE WITH EXISTING DATA ### ### MERGE WITH EXISTING DATA ###
merged_dictionary = {} merged_dictionary = {}
@ -7198,10 +7182,9 @@ async def google_login(request: Request):
) )
####### Detect DB + MASTER KEY in .env ####### ####### Detect DB + MASTER KEY in .env #######
if prisma_client is None or master_key is None: missing_env_vars = show_missing_vars_in_env()
from fastapi.responses import HTMLResponse if missing_env_vars is not None:
return missing_env_vars
return HTMLResponse(content=missing_keys_html_form, status_code=200)
# get url from request # get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
@ -8404,11 +8387,8 @@ async def update_config(config_info: ConfigYAML):
# encrypt updated_environment_variables # # encrypt updated_environment_variables #
for k, v in _updated_environment_variables.items(): for k, v in _updated_environment_variables.items():
if isinstance(v, str): encrypted_value = encrypt_value_helper(value=v)
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore _updated_environment_variables[k] = encrypted_value
_updated_environment_variables[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
_existing_env_variables = config["environment_variables"] _existing_env_variables = config["environment_variables"]
@ -8825,11 +8805,8 @@ async def get_config():
env_vars_dict[_var] = None env_vars_dict[_var] = None
else: else:
# decode + decrypt the value # decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable) decrypted_value = decrypt_value_helper(value=env_variable)
_decrypted_value = decrypt_value( env_vars_dict[_var] = decrypted_value
value=decoded_b64, master_key=master_key
)
env_vars_dict[_var] = _decrypted_value
_data_to_return.append({"name": _callback, "variables": env_vars_dict}) _data_to_return.append({"name": _callback, "variables": env_vars_dict})
elif _callback == "langfuse": elif _callback == "langfuse":
@ -8845,11 +8822,8 @@ async def get_config():
_langfuse_env_vars[_var] = None _langfuse_env_vars[_var] = None
else: else:
# decode + decrypt the value # decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable) decrypted_value = decrypt_value_helper(value=env_variable)
_decrypted_value = decrypt_value( _langfuse_env_vars[_var] = decrypted_value
value=decoded_b64, master_key=master_key
)
_langfuse_env_vars[_var] = _decrypted_value
_data_to_return.append( _data_to_return.append(
{"name": _callback, "variables": _langfuse_env_vars} {"name": _callback, "variables": _langfuse_env_vars}
@ -8870,10 +8844,7 @@ async def get_config():
_slack_env_vars[_var] = _value _slack_env_vars[_var] = _value
else: else:
# decode + decrypt the value # decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable) _decrypted_value = decrypt_value_helper(value=env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
_slack_env_vars[_var] = _decrypted_value _slack_env_vars[_var] = _decrypted_value
_alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types _alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types
@ -8909,10 +8880,7 @@ async def get_config():
_email_env_vars[_var] = None _email_env_vars[_var] = None
else: else:
# decode + decrypt the value # decode + decrypt the value
decoded_b64 = base64.b64decode(env_variable) _decrypted_value = decrypt_value_helper(value=env_variable)
_decrypted_value = decrypt_value(
value=decoded_b64, master_key=master_key
)
_email_env_vars[_var] = _decrypted_value _email_env_vars[_var] = _decrypted_value
alerting_data.append( alerting_data.append(

View file

@ -2705,178 +2705,6 @@ def _is_valid_team_configs(team_id=None, team_config=None, request_data=None):
return return
def encrypt_value(value: str, master_key: str):
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(master_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# encode message #
value_bytes = value.encode("utf-8")
encrypted = box.encrypt(value_bytes)
return encrypted
def decrypt_value(value: bytes, master_key: str) -> str:
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(master_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# Convert the bytes object to a string
plaintext = box.decrypt(value)
plaintext = plaintext.decode("utf-8") # type: ignore
return plaintext # type: ignore
# LiteLLM Admin UI - Non SSO Login
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
url_to_redirect_to += "/login"
html_form = f"""
<!DOCTYPE html>
<html>
<head>
<title>LiteLLM Login</title>
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}}
form {{
background-color: #fff;
padding: 20px;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
label {{
display: block;
margin-bottom: 8px;
}}
input {{
width: 100%;
padding: 8px;
margin-bottom: 16px;
box-sizing: border-box;
border: 1px solid #ccc;
border-radius: 4px;
}}
input[type="submit"] {{
background-color: #4caf50;
color: #fff;
cursor: pointer;
}}
input[type="submit"]:hover {{
background-color: #45a049;
}}
</style>
</head>
<body>
<form action="{url_to_redirect_to}" method="post">
<h2>LiteLLM Login</h2>
<p>By default Username is "admin" and Password is your set LiteLLM Proxy `MASTER_KEY`</p>
<p>If you need to set UI credentials / SSO docs here: <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">https://docs.litellm.ai/docs/proxy/ui</a></p>
<br>
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
<label for="password">Password:</label>
<input type="password" id="password" name="password" required>
<input type="submit" value="Submit">
</form>
"""
missing_keys_html_form = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}
.container {
max-width: 600px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
h1 {
font-size: 24px;
margin-bottom: 20px;
}
pre {
background: #f8f8f8;
padding: 10px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}
.env-var {
font-weight: normal;
}
.comment {
font-weight: normal;
color: #777;
}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following configurations to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># make this unique. must start with `sk-`.</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
</div>
</body>
</html>
"""
def _to_ns(dt): def _to_ns(dt):
return int(dt.timestamp() * 1e9) return int(dt.timestamp() * 1e9)

View file

@ -2,23 +2,30 @@
## Unit tests for ProxyConfig class ## Unit tests for ProxyConfig class
import sys, os import os
import sys
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import os, io import io
import os
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest, litellm
from pydantic import BaseModel, ConfigDict
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
from typing import Literal from typing import Literal
import pytest
from pydantic import BaseModel, ConfigDict
import litellm
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import DualCache, ProxyLogging
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
class DBModel(BaseModel): class DBModel(BaseModel):
model_id: str model_id: str
@ -28,6 +35,7 @@ class DBModel(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_deployment(): async def test_delete_deployment():
""" """