feat(proxy_server.py): abstract config update/writing and support persisting config in db

allows user to opt into writing to db (SAVE_CONFIG_TO_DB) and removes any api keys before sending to db

 https://github.com/BerriAI/litellm/issues/1322
This commit is contained in:
Krrish Dholakia 2024-01-04 14:44:45 +05:30
parent 6dea0d3115
commit 99d9a825de
4 changed files with 430 additions and 309 deletions

View file

@ -502,232 +502,331 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval)
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
config = {}
try:
if os.path.exists(config_file_path):
class ProxyConfig:
"""
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
"""
def __init__(self) -> None:
pass
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
global prisma_client, user_config_file_path
file_path = config_file_path or user_config_file_path
if config_file_path is not None:
user_config_file_path = config_file_path
with open(config_file_path, "r") as file:
config = yaml.safe_load(file)
# Load existing config
## Yaml
if os.path.exists(f"{file_path}"):
with open(f"{file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
raise Exception(
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
config = {
"model_list": [],
"general_settings": {},
"router_settings": {},
"litellm_settings": {},
}
## DB
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
):
_tasks = []
keys = [
"model_list",
"general_settings",
"router_settings",
"litellm_settings",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
return config
async def save_config(self, new_config: dict):
global prisma_client, llm_router, user_config_file_path
# Load existing config
backup_config = await self.get_config()
# Save the updated config
## YAML
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(new_config, config_file, default_flow_style=False)
# update Router - verifies if this is a valid config
try:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
raise Exception(f"Exception while reading Config: {e}")
except Exception as e:
traceback.print_exc()
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid config passed in")
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
## DB - writes valid config to db
"""
- Do not write restricted params like 'api_key' to the database
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
"""
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
):
### KEY REMOVAL ###
models = new_config.get("model_list", [])
for m in models:
if m.get("litellm_params", {}).get("api_key", None) is not None:
# pop the key
api_key = m["litellm_params"].pop("api_key")
# store in local env
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
os.environ[key_name] = api_key
# save the key name (not the value)
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
await prisma_client.insert_data(data=new_config, table_name="config")
print_verbose(
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
)
async def load_config(
self, router: Optional[litellm.Router], config_file_path: str
):
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
## ENVIRONMENT VARIABLES
environment_variables = config.get("environment_variables", None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", None)
if litellm_settings is None:
litellm_settings = {}
if litellm_settings:
# ANSI escape code for blue text
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
for key, value in litellm_settings.items():
if key == "cache":
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
from litellm.caching import Cache
print_verbose(
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
)
cache_params = {}
if "cache_params" in litellm_settings:
cache_params_in_config = litellm_settings["cache_params"]
# overwrie cache_params with cache_params_in_config
cache_params.update(cache_params_in_config)
## ENVIRONMENT VARIABLES
environment_variables = config.get("environment_variables", None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
cache_type = cache_params.get("type", "redis")
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", None)
if litellm_settings is None:
litellm_settings = {}
if litellm_settings:
# ANSI escape code for blue text
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
for key, value in litellm_settings.items():
if key == "cache":
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
from litellm.caching import Cache
print_verbose(f"passed cache type={cache_type}")
cache_params = {}
if "cache_params" in litellm_settings:
cache_params_in_config = litellm_settings["cache_params"]
# overwrie cache_params with cache_params_in_config
cache_params.update(cache_params_in_config)
if cache_type == "redis":
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
cache_type = cache_params.get("type", "redis")
cache_params = {
"type": cache_type,
"host": cache_host,
"port": cache_port,
"password": cache_password,
}
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
print_verbose(f"passed cache type={cache_type}")
if cache_type == "redis":
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
cache_params = {
"type": cache_type,
"host": cache_host,
"port": cache_port,
"password": cache_password,
}
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
print( # noqa
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
)
print() # noqa
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
litellm.cache = Cache(**cache_params)
print( # noqa
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
)
print() # noqa
elif key == "callbacks":
litellm.callbacks = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
elif key == "post_call_rules":
litellm.post_call_rules = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
elif key == "success_callback":
litellm.success_callback = []
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
litellm.cache = Cache(**cache_params)
print( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
)
elif key == "callbacks":
litellm.callbacks = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
elif key == "post_call_rules":
litellm.post_call_rules = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
elif key == "success_callback":
litellm.success_callback = []
# intialize success callbacks
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.success_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.success_callback.append(callback)
print_verbose(
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
)
elif key == "failure_callback":
litellm.failure_callback = []
# intialize success callbacks
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.success_callback.append(get_instance_fn(value=callback))
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.success_callback.append(callback)
print_verbose(
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
)
elif key == "failure_callback":
litellm.failure_callback = []
# intialize success callbacks
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.failure_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.failure_callback.append(callback)
print_verbose(
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
)
elif key == "cache_params":
# this is set in the cache branch
# see usage here: https://docs.litellm.ai/docs/proxy/caching
pass
else:
setattr(litellm, key, value)
# intialize success callbacks
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.failure_callback.append(get_instance_fn(value=callback))
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.failure_callback.append(callback)
print_verbose(
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
)
elif key == "cache_params":
# this is set in the cache branch
# see usage here: https://docs.litellm.ai/docs/proxy/caching
pass
else:
setattr(litellm, key, value)
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
general_settings = config.get("general_settings", {})
if general_settings is None:
general_settings = {}
if general_settings:
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get("key_management_system", None)
if key_management_system is not None:
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### ALERTING ###
proxy_logging_obj.update_values(
alerting=general_settings.get("alerting", None),
alerting_threshold=general_settings.get("alerting_threshold", 600),
)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
print_verbose(f"RETRIEVED DB URL: {database_url}")
prisma_setup(database_url=database_url)
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
general_settings = config.get("general_settings", {})
if general_settings is None:
general_settings = {}
if general_settings:
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get("key_management_system", None)
if key_management_system is not None:
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### ALERTING ###
proxy_logging_obj.update_values(
alerting=general_settings.get("alerting", None),
alerting_threshold=general_settings.get("alerting_threshold", 600),
)
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
"background_health_checks", False
)
health_check_interval = general_settings.get("health_check_interval", 300)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
print_verbose(f"RETRIEVED DB URL: {database_url}")
prisma_setup(database_url=database_url)
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
)
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
"background_health_checks", False
)
health_check_interval = general_settings.get("health_check_interval", 300)
router_params: dict = {
"num_retries": 3,
"cache_responses": litellm.cache
!= None, # cache if user passed in cache values
}
## MODEL LIST
model_list = config.get("model_list", None)
if model_list:
router_params["model_list"] = model_list
print( # noqa
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
) # noqa
for model in model_list:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
router_params: dict = {
"num_retries": 3,
"cache_responses": litellm.cache
!= None, # cache if user passed in cache values
}
## MODEL LIST
model_list = config.get("model_list", None)
if model_list:
router_params["model_list"] = model_list
print( # noqa
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
) # noqa
for model in model_list:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
available_args = [x for x in arg_spec.args if x not in exclude_args]
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
}
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
available_args = [x for x in arg_spec.args if x not in exclude_args]
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
proxy_config = ProxyConfig()
async def generate_key_helper_fn(
@ -856,10 +955,6 @@ def initialize(
if debug == True: # this needs to be first, so users can see Router init debugg
litellm.set_verbose = True
dynamic_config = {"general": {}, user_model: {}}
if config:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=config
)
if headers: # model-specific param
user_headers = headers
dynamic_config[user_model]["headers"] = headers
@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
import json
### LOAD MASTER KEY ###
@ -1000,10 +1095,26 @@ async def startup_event():
print_verbose(f"worker_config: {worker_config}")
# check if it's a valid file path
if os.path.isfile(worker_config):
initialize(config=worker_config)
if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
### Check if user email in user table
response = await prisma_client.get_generic_data(
key="user_email", value=user_email, db="users"
key="user_email", value=user_email, table_name="users"
)
### if so - generate a 24 hr key with that user id
if response is not None:
@ -1883,16 +1994,13 @@ async def user_update(request: Request):
dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
global llm_router, llm_model_list, general_settings, user_config_file_path
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try:
print_verbose(f"User config path: {user_config_file_path}")
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []}
backup_config = copy.deepcopy(config)
config = await proxy_config.get_config()
print_verbose(f"User config path: {user_config_file_path}")
print_verbose(f"Loaded config: {config}")
# Add the new model to the config
model_info = model_params.model_info.json()
@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
print_verbose(f"updated model list: {config['model_list']}")
# Save the updated config
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Rever to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid Model passed in")
print_verbose(f"llm_model_list: {llm_model_list}")
# Save new config
await proxy_config.save_config(new_config=config)
return {"message": "Model added successfully"}
except Exception as e:
@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
dependencies=[Depends(user_api_key_auth)],
)
async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path
global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []} # handle base case
config = await proxy_config.get_config()
all_models = config["model_list"]
for model in all_models:
@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(model_info: ModelInfoDelete):
global llm_router, llm_model_list, general_settings, user_config_file_path
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try:
if not os.path.exists(user_config_file_path):
raise HTTPException(status_code=404, detail="Config file does not exist.")
with open(user_config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
# Load existing config
config = await proxy_config.get_config()
# If model_list is not in the config, nothing can be deleted
if "model_list" not in config:
if len(config.get("model_list", [])) == 0:
raise HTTPException(
status_code=404, detail="No model list available in the config."
status_code=400, detail="No model list available in the config."
)
# Check if the model with the specified model_id exists
@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
# If the model was not found, return an error
if model_to_delete is None:
raise HTTPException(
status_code=404, detail="Model with given model_id not found."
status_code=400, detail="Model with given model_id not found."
)
# Remove model from the list and save the updated config
config["model_list"].remove(model_to_delete)
with open(user_config_file_path, "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# Update Router
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
# Save updated config
config = await proxy_config.save_config(new_config=config)
return {"message": "Model deleted successfully"}
except HTTPException as e:
@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
Currently supports modifying General Settings + LiteLLM settings
"""
global llm_router, llm_model_list, general_settings
global llm_router, llm_model_list, general_settings, proxy_config
try:
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {}
config = await proxy_config.get_config()
backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}")
@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
}
# Save the updated config
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Rever to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(
status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
)
config = await proxy_config.save_config(new_config=config)
return {"message": "Config updated successfully"}
except HTTPException as e:
raise e

View file

@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
}
model LiteLLM_Config {
param_name String @id
param_value Json?
}

View file

@ -301,20 +301,24 @@ class PrismaClient:
self,
key: str,
value: Any,
db: Literal["users", "keys"],
table_name: Literal["users", "keys", "config"],
):
"""
Generic implementation of get data
"""
try:
if db == "users":
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
where={key: value} # type: ignore
)
elif db == "keys":
elif table_name == "keys":
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
where={key: value} # type: ignore
)
elif table_name == "config":
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
return response
except Exception as e:
asyncio.create_task(
@ -385,39 +389,66 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(self, data: dict):
async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
):
"""
Add a key to the database. If it already exists, do nothing.
"""
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
"token": hashed_token,
},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]},
data={
"create": {
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
if table_name == "user+key":
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
"token": hashed_token,
},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]},
data={
"create": {
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
elif table_name == "config":
"""
For each param,
get the existing table values
Add the new values
Update DB
"""
tasks = []
for k, v in data.items():
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
data={
"create": {"param_name": k, "param_value": updated_data},
"update": {"param_value": updated_data},
},
)
tasks.append(updated_table_row)
await asyncio.gather(*tasks)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@ -527,6 +558,7 @@ class PrismaClient:
async def disconnect(self):
try:
await self.db.disconnect()
self.connected = False
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)

View file

@ -9,7 +9,7 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
import litellm, openai
import itertools
@ -6621,7 +6621,7 @@ def _is_base64(s):
def get_secret(
secret_name: str,
default_value: Optional[str] = None,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"):
@ -6672,9 +6672,24 @@ def get_secret(
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
secret = os.getenv(secret_name)
return secret
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
return os.environ.get(secret_name)
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
except Exception as e:
if default_value is not None:
return default_value