forked from phoenix/litellm-mirror
feat(proxy_server.py): abstract config update/writing and support persisting config in db
allows the user to opt into writing the config to the db (SAVE_CONFIG_TO_DB) and removes any api keys before sending the config to the db. https://github.com/BerriAI/litellm/issues/1322
This commit is contained in: parent 6dea0d3115, commit 99d9a825de
4 changed files with 430 additions and 309 deletions
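For orientation before the diff: when `SAVE_CONFIG_TO_DB` is truthy, raw `api_key` values are moved into process environment variables and only an `os.environ/<name>` reference is persisted. A minimal illustrative sketch of that scrubbing step, not the committed code; the helper name `scrub_api_keys` is hypothetical, while `SAVE_CONFIG_TO_DB` and the `LITELLM_MODEL_KEY_<uuid>` naming come from the diff itself:

```python
# Illustrative sketch of the key-scrubbing step this commit adds to save_config.
# Hypothetical standalone helper; only the LITELLM_MODEL_KEY_<uuid> naming and
# the os.environ/ reference convention are taken from the actual diff below.
import os
import uuid


def scrub_api_keys(new_config: dict) -> dict:
    """Replace raw api_key values with os.environ/ references before a DB write."""
    for m in new_config.get("model_list", []):
        api_key = m.get("litellm_params", {}).pop("api_key", None)
        if api_key is not None:
            key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
            os.environ[key_name] = api_key  # the secret stays on the proxy host
            # persist only a reference, resolved later via litellm.get_secret
            m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
    return new_config
```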
@@ -502,232 +502,331 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
 
-def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
-    config = {}
-    try:
-        if os.path.exists(config_file_path):
-            with open(config_file_path, "r") as file:
-                config = yaml.safe_load(file)
-        else:
-            raise Exception(
-                f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
-            )
-    except Exception as e:
-        raise Exception(f"Exception while reading Config: {e}")
-
-    ## PRINT YAML FOR CONFIRMING IT WORKS
-    printed_yaml = copy.deepcopy(config)
-    printed_yaml.pop("environment_variables", None)
+class ProxyConfig:
+    """
+    Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
+        global prisma_client, user_config_file_path
+
+        file_path = config_file_path or user_config_file_path
+        if config_file_path is not None:
+            user_config_file_path = config_file_path
+        # Load existing config
+        ## Yaml
+        if os.path.exists(f"{file_path}"):
+            with open(f"{file_path}", "r") as config_file:
+                config = yaml.safe_load(config_file)
+        else:
+            config = {
+                "model_list": [],
+                "general_settings": {},
+                "router_settings": {},
+                "litellm_settings": {},
+            }
+
+        ## DB
+        if (
+            prisma_client is not None
+            and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
+        ):
+            _tasks = []
+            keys = [
+                "model_list",
+                "general_settings",
+                "router_settings",
+                "litellm_settings",
+            ]
+            for k in keys:
+                response = prisma_client.get_generic_data(
+                    key="param_name", value=k, table_name="config"
+                )
+                _tasks.append(response)
+
+            responses = await asyncio.gather(*_tasks)
+
+        return config
+
+    async def save_config(self, new_config: dict):
+        global prisma_client, llm_router, user_config_file_path
+        # Load existing config
+        backup_config = await self.get_config()
+
+        # Save the updated config
+        ## YAML
+        with open(f"{user_config_file_path}", "w") as config_file:
+            yaml.dump(new_config, config_file, default_flow_style=False)
+
+        # update Router - verifies if this is a valid config
+        try:
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=user_config_file_path
+            )
+        except Exception as e:
+            traceback.print_exc()
+            # Revert to old config instead
+            with open(f"{user_config_file_path}", "w") as config_file:
+                yaml.dump(backup_config, config_file, default_flow_style=False)
+            raise HTTPException(status_code=400, detail="Invalid config passed in")
+
+        ## DB - writes valid config to db
+        """
+        - Do not write restricted params like 'api_key' to the database
+        - if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
+        """
+        if (
+            prisma_client is not None
+            and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
+        ):
+            ### KEY REMOVAL ###
+            models = new_config.get("model_list", [])
+            for m in models:
+                if m.get("litellm_params", {}).get("api_key", None) is not None:
+                    # pop the key
+                    api_key = m["litellm_params"].pop("api_key")
+                    # store in local env
+                    key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
+                    os.environ[key_name] = api_key
+                    # save the key name (not the value)
+                    m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
+            await prisma_client.insert_data(data=new_config, table_name="config")
+
+    async def load_config(
+        self, router: Optional[litellm.Router], config_file_path: str
+    ):
+        """
+        Load config values into proxy global state
+        """
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
+
+        # Load existing config
+        config = await self.get_config(config_file_path=config_file_path)
+        ## PRINT YAML FOR CONFIRMING IT WORKS
+        printed_yaml = copy.deepcopy(config)
+        printed_yaml.pop("environment_variables", None)
+
+        print_verbose(
+            f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
+        )
+
+        ## ENVIRONMENT VARIABLES
+        environment_variables = config.get("environment_variables", None)
+        if environment_variables:
+            for key, value in environment_variables.items():
+                os.environ[key] = value
+
+        ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
+        litellm_settings = config.get("litellm_settings", None)
+        if litellm_settings is None:
+            litellm_settings = {}
+        if litellm_settings:
+            # ANSI escape code for blue text
+            blue_color_code = "\033[94m"
+            reset_color_code = "\033[0m"
+            for key, value in litellm_settings.items():
+                if key == "cache":
+                    print(f"{blue_color_code}\nSetting Cache on Proxy")  # noqa
+                    from litellm.caching import Cache
+
+                    cache_params = {}
+                    if "cache_params" in litellm_settings:
+                        cache_params_in_config = litellm_settings["cache_params"]
+                        # overwrie cache_params with cache_params_in_config
+                        cache_params.update(cache_params_in_config)
+
+                    cache_type = cache_params.get("type", "redis")
+
+                    print_verbose(f"passed cache type={cache_type}")
+
+                    if cache_type == "redis":
+                        cache_host = litellm.get_secret("REDIS_HOST", None)
+                        cache_port = litellm.get_secret("REDIS_PORT", None)
+                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)
+
+                        cache_params = {
+                            "type": cache_type,
+                            "host": cache_host,
+                            "port": cache_port,
+                            "password": cache_password,
+                        }
+                        # Assuming cache_type, cache_host, cache_port, and cache_password are strings
+                        print(  # noqa
+                            f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
+                        )
+                        print()  # noqa
+
+                    ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
+                    litellm.cache = Cache(**cache_params)
+                    print(  # noqa
+                        f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
+                    )
+                elif key == "callbacks":
+                    litellm.callbacks = [
+                        get_instance_fn(value=value, config_file_path=config_file_path)
+                    ]
+                    print_verbose(
+                        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
+                    )
+                elif key == "post_call_rules":
+                    litellm.post_call_rules = [
+                        get_instance_fn(value=value, config_file_path=config_file_path)
+                    ]
+                    print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
+                elif key == "success_callback":
+                    litellm.success_callback = []
+
+                    # intialize success callbacks
+                    for callback in value:
+                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
+                        if "." in callback:
+                            litellm.success_callback.append(
+                                get_instance_fn(value=callback)
+                            )
+                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
+                        else:
+                            litellm.success_callback.append(callback)
+                    print_verbose(
+                        f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
+                    )
+                elif key == "failure_callback":
+                    litellm.failure_callback = []
+
+                    # intialize success callbacks
+                    for callback in value:
+                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
+                        if "." in callback:
+                            litellm.failure_callback.append(
+                                get_instance_fn(value=callback)
+                            )
+                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
+                        else:
+                            litellm.failure_callback.append(callback)
+                    print_verbose(
+                        f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
+                    )
+                elif key == "cache_params":
+                    # this is set in the cache branch
+                    # see usage here: https://docs.litellm.ai/docs/proxy/caching
+                    pass
+                else:
+                    setattr(litellm, key, value)
+
+        ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
+        general_settings = config.get("general_settings", {})
+        if general_settings is None:
+            general_settings = {}
+        if general_settings:
+            ### LOAD SECRET MANAGER ###
+            key_management_system = general_settings.get("key_management_system", None)
+            if key_management_system is not None:
+                if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
+                    ### LOAD FROM AZURE KEY VAULT ###
+                    load_from_azure_key_vault(use_azure_key_vault=True)
+                elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
+                    ### LOAD FROM GOOGLE KMS ###
+                    load_google_kms(use_google_kms=True)
+                else:
+                    raise ValueError("Invalid Key Management System selected")
+            ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
+            use_google_kms = general_settings.get("use_google_kms", False)
+            load_google_kms(use_google_kms=use_google_kms)
+            ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
+            use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
+            load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+            ### ALERTING ###
+            proxy_logging_obj.update_values(
+                alerting=general_settings.get("alerting", None),
+                alerting_threshold=general_settings.get("alerting_threshold", 600),
+            )
+            ### CONNECT TO DATABASE ###
+            database_url = general_settings.get("database_url", None)
+            if database_url and database_url.startswith("os.environ/"):
+                print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
+                database_url = litellm.get_secret(database_url)
+                print_verbose(f"RETRIEVED DB URL: {database_url}")
+            prisma_setup(database_url=database_url)
+            ## COST TRACKING ##
+            cost_tracking()
+            ### MASTER KEY ###
+            master_key = general_settings.get(
+                "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
+            )
+            if master_key and master_key.startswith("os.environ/"):
+                master_key = litellm.get_secret(master_key)
+            ### CUSTOM API KEY AUTH ###
+            custom_auth = general_settings.get("custom_auth", None)
+            if custom_auth:
+                user_custom_auth = get_instance_fn(
+                    value=custom_auth, config_file_path=config_file_path
+                )
+            ### BACKGROUND HEALTH CHECKS ###
+            # Enable background health checks
+            use_background_health_checks = general_settings.get(
+                "background_health_checks", False
+            )
+            health_check_interval = general_settings.get("health_check_interval", 300)
+
+        router_params: dict = {
+            "num_retries": 3,
+            "cache_responses": litellm.cache
+            != None,  # cache if user passed in cache values
+        }
+        ## MODEL LIST
+        model_list = config.get("model_list", None)
+        if model_list:
+            router_params["model_list"] = model_list
+            print(  # noqa
+                f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
+            )  # noqa
+            for model in model_list:
+                ### LOAD FROM os.environ/ ###
+                for k, v in model["litellm_params"].items():
+                    if isinstance(v, str) and v.startswith("os.environ/"):
+                        model["litellm_params"][k] = litellm.get_secret(v)
+                print(f"\033[32m {model.get('model_name', '')}\033[0m")  # noqa
+                litellm_model_name = model["litellm_params"]["model"]
+                litellm_model_api_base = model["litellm_params"].get("api_base", None)
+                if "ollama" in litellm_model_name and litellm_model_api_base is None:
+                    run_ollama_serve()
+
+        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+        router_settings = config.get("router_settings", None)
+        if router_settings and isinstance(router_settings, dict):
+            arg_spec = inspect.getfullargspec(litellm.Router)
+            # model list already set
+            exclude_args = {
+                "self",
+                "model_list",
+            }
+
+            available_args = [x for x in arg_spec.args if x not in exclude_args]
+
+            for k, v in router_settings.items():
+                if k in available_args:
+                    router_params[k] = v
+
+        router = litellm.Router(**router_params)  # type:ignore
+        return router, model_list, general_settings
+
+
+proxy_config = ProxyConfig()
 
 
 async def generate_key_helper_fn(
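The endpoints updated further down (`add_new_model`, `model_info_v1`, `delete_model`, `update_config`) all converge on the same read-modify-save pattern against this singleton. Roughly, as an illustrative sketch rather than committed code:

```python
# Sketch of the read-modify-save pattern the updated endpoints follow.
# save_config validates by rebuilding the router and reverts the YAML on failure.
async def add_model_example(proxy_config, model_entry: dict):
    config = await proxy_config.get_config()  # YAML (+ DB rows when enabled)
    config.setdefault("model_list", []).append(model_entry)
    await proxy_config.save_config(new_config=config)
```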
@@ -856,10 +955,6 @@ def initialize(
     if debug == True:  # this needs to be first, so users can see Router init debugg
         litellm.set_verbose = True
     dynamic_config = {"general": {}, user_model: {}}
-    if config:
-        llm_router, llm_model_list, general_settings = load_router_config(
-            router=llm_router, config_file_path=config
-        )
     if headers:  # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
@@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
     import json
 
     ### LOAD MASTER KEY ###
@@ -1000,10 +1095,26 @@ async def startup_event():
     print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):
-        initialize(config=worker_config)
+        if worker_config.get("config", None) is not None:
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=worker_config.pop("config")
+            )
+        initialize(**worker_config)
     else:
         # if not, assume it's a json string
         worker_config = json.loads(os.getenv("WORKER_CONFIG"))
+        if worker_config.get("config", None) is not None:
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=worker_config.pop("config")
+            )
         initialize(**worker_config)
 
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
@@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
 
     ### Check if user email in user table
     response = await prisma_client.get_generic_data(
-        key="user_email", value=user_email, db="users"
+        key="user_email", value=user_email, table_name="users"
     )
     ### if so - generate a 24 hr key with that user id
     if response is not None:
@@ -1883,16 +1994,13 @@ async def user_update(request: Request):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def add_new_model(model_params: ModelParams):
-    global llm_router, llm_model_list, general_settings, user_config_file_path
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
     try:
-        print_verbose(f"User config path: {user_config_file_path}")
         # Load existing config
-        if os.path.exists(f"{user_config_file_path}"):
-            with open(f"{user_config_file_path}", "r") as config_file:
-                config = yaml.safe_load(config_file)
-        else:
-            config = {"model_list": []}
-        backup_config = copy.deepcopy(config)
+        config = await proxy_config.get_config()
+
+        print_verbose(f"User config path: {user_config_file_path}")
+
         print_verbose(f"Loaded config: {config}")
         # Add the new model to the config
         model_info = model_params.model_info.json()
@@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
 
         print_verbose(f"updated model list: {config['model_list']}")
 
-        # Save the updated config
-        with open(f"{user_config_file_path}", "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
-
-        # update Router
-        try:
-            llm_router, llm_model_list, general_settings = load_router_config(
-                router=llm_router, config_file_path=user_config_file_path
-            )
-        except Exception as e:
-            # Rever to old config instead
-            with open(f"{user_config_file_path}", "w") as config_file:
-                yaml.dump(backup_config, config_file, default_flow_style=False)
-            raise HTTPException(status_code=400, detail="Invalid Model passed in")
-
-        print_verbose(f"llm_model_list: {llm_model_list}")
+        # Save new config
+        await proxy_config.save_config(new_config=config)
         return {"message": "Model added successfully"}
 
     except Exception as e:
@@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def model_info_v1(request: Request):
-    global llm_model_list, general_settings, user_config_file_path
+    global llm_model_list, general_settings, user_config_file_path, proxy_config
 
     # Load existing config
-    if os.path.exists(f"{user_config_file_path}"):
-        with open(f"{user_config_file_path}", "r") as config_file:
-            config = yaml.safe_load(config_file)
-    else:
-        config = {"model_list": []}  # handle base case
+    config = await proxy_config.get_config()
 
     all_models = config["model_list"]
     for model in all_models:
@@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def delete_model(model_info: ModelInfoDelete):
-    global llm_router, llm_model_list, general_settings, user_config_file_path
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
     try:
-        if not os.path.exists(user_config_file_path):
-            raise HTTPException(status_code=404, detail="Config file does not exist.")
-
-        with open(user_config_file_path, "r") as config_file:
-            config = yaml.safe_load(config_file)
+        # Load existing config
+        config = await proxy_config.get_config()
 
         # If model_list is not in the config, nothing can be deleted
-        if "model_list" not in config:
+        if len(config.get("model_list", [])) == 0:
             raise HTTPException(
-                status_code=404, detail="No model list available in the config."
+                status_code=400, detail="No model list available in the config."
             )
 
         # Check if the model with the specified model_id exists
@@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
         # If the model was not found, return an error
         if model_to_delete is None:
             raise HTTPException(
-                status_code=404, detail="Model with given model_id not found."
+                status_code=400, detail="Model with given model_id not found."
             )
 
         # Remove model from the list and save the updated config
         config["model_list"].remove(model_to_delete)
-        with open(user_config_file_path, "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
 
-        # Update Router
-        llm_router, llm_model_list, general_settings = load_router_config(
-            router=llm_router, config_file_path=user_config_file_path
-        )
+        # Save updated config
+        config = await proxy_config.save_config(new_config=config)
         return {"message": "Model deleted successfully"}
 
     except HTTPException as e:
@@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
 
     Currently supports modifying General Settings + LiteLLM settings
     """
-    global llm_router, llm_model_list, general_settings
+    global llm_router, llm_model_list, general_settings, proxy_config
    try:
         # Load existing config
-        if os.path.exists(f"{user_config_file_path}"):
-            with open(f"{user_config_file_path}", "r") as config_file:
-                config = yaml.safe_load(config_file)
-        else:
-            config = {}
+        config = await proxy_config.get_config()
 
         backup_config = copy.deepcopy(config)
         print_verbose(f"Loaded config: {config}")
@@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
         }
 
-        # Save the updated config
-        with open(f"{user_config_file_path}", "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
-
-        # update Router
-        try:
-            llm_router, llm_model_list, general_settings = load_router_config(
-                router=llm_router, config_file_path=user_config_file_path
-            )
-        except Exception as e:
-            # Rever to old config instead
-            with open(f"{user_config_file_path}", "w") as config_file:
-                yaml.dump(backup_config, config_file, default_flow_style=False)
-            raise HTTPException(
-                status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
-            )
+        config = await proxy_config.save_config(new_config=config)
         return {"message": "Config updated successfully"}
     except HTTPException as e:
         raise e
@@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
   user_id String?
   max_parallel_requests Int?
   metadata Json @default("{}")
 }
+
+model LiteLLM_Config {
+  param_name String @id
+  param_value Json?
+}
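Each top-level config section maps to one `LiteLLM_Config` row: `param_name` is the section name and `param_value` holds the JSON-serialized section. Reading a section back directly might look like the sketch below, assuming the Prisma Python client generated from this schema; the proxy itself goes through `PrismaClient.get_generic_data` instead:

```python
# Sketch: query one config section straight from the new table.
# Assumes the generated Prisma Python client; not part of this commit.
from prisma import Prisma


async def read_config_section(param_name: str):
    db = Prisma()
    await db.connect()
    # one row per section, e.g. param_name="model_list"
    row = await db.litellm_config.find_first(where={"param_name": param_name})
    await db.disconnect()
    return row.param_value if row else None
```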
@@ -301,20 +301,24 @@ class PrismaClient:
         self,
         key: str,
         value: Any,
-        db: Literal["users", "keys"],
+        table_name: Literal["users", "keys", "config"],
     ):
         """
         Generic implementation of get data
         """
         try:
-            if db == "users":
+            if table_name == "users":
                 response = await self.db.litellm_usertable.find_first(
                     where={key: value}  # type: ignore
                 )
-            elif db == "keys":
+            elif table_name == "keys":
                 response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
                     where={key: value}  # type: ignore
                 )
+            elif table_name == "config":
+                response = await self.db.litellm_config.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
             return response
         except Exception as e:
             asyncio.create_task(
@@ -385,39 +389,66 @@ class PrismaClient:
         max_time=10,  # maximum total time to retry for
         on_backoff=on_backoff,  # specifying the function to call on backoff
     )
-    async def insert_data(self, data: dict):
+    async def insert_data(
+        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
+    ):
         """
         Add a key to the database. If it already exists, do nothing.
         """
         try:
-            token = data["token"]
-            hashed_token = self.hash_token(token=token)
-            db_data = self.jsonify_object(data=data)
-            db_data["token"] = hashed_token
-            max_budget = db_data.pop("max_budget", None)
-            user_email = db_data.pop("user_email", None)
-            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
-                where={
-                    "token": hashed_token,
-                },
-                data={
-                    "create": {**db_data},  # type: ignore
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-
-            new_user_row = await self.db.litellm_usertable.upsert(
-                where={"user_id": data["user_id"]},
-                data={
-                    "create": {
-                        "user_id": data["user_id"],
-                        "max_budget": max_budget,
-                        "user_email": user_email,
-                    },
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-            return new_verification_token
+            if table_name == "user+key":
+                token = data["token"]
+                hashed_token = self.hash_token(token=token)
+                db_data = self.jsonify_object(data=data)
+                db_data["token"] = hashed_token
+                max_budget = db_data.pop("max_budget", None)
+                user_email = db_data.pop("user_email", None)
+                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                    where={
+                        "token": hashed_token,
+                    },
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+
+                new_user_row = await self.db.litellm_usertable.upsert(
+                    where={"user_id": data["user_id"]},
+                    data={
+                        "create": {
+                            "user_id": data["user_id"],
+                            "max_budget": max_budget,
+                            "user_email": user_email,
+                        },
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                return new_verification_token
+            elif table_name == "config":
+                """
+                For each param,
+                get the existing table values
+
+                Add the new values
+
+                Update DB
+                """
+                tasks = []
+                for k, v in data.items():
+                    updated_data = v
+                    updated_data = json.dumps(updated_data)
+                    updated_table_row = self.db.litellm_config.upsert(
+                        where={"param_name": k},
+                        data={
+                            "create": {"param_name": k, "param_value": updated_data},
+                            "update": {"param_value": updated_data},
+                        },
+                    )
+
+                    tasks.append(updated_table_row)
+
+                await asyncio.gather(*tasks)
         except Exception as e:
             asyncio.create_task(
                 self.proxy_logging_obj.failure_handler(original_exception=e)
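Given the `config` branch above, a config write fans out into one upsert per top-level key, so a partial update only touches the sections present in `data`. A hypothetical call shape for illustration (the `os.environ/...` value mirrors the key-removal step in `save_config`; the `<uuid>` placeholder stands in for a generated name):

```python
# Hypothetical payload for insert_data(..., table_name="config"):
# each top-level key becomes one LiteLLM_Config row via upsert.
new_config = {
    "model_list": [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                # reference only; the raw key was moved into the local env
                "api_key": "os.environ/LITELLM_MODEL_KEY_<uuid>",
            },
        }
    ],
    "general_settings": {},
}
# await prisma_client.insert_data(data=new_config, table_name="config")
```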
@@ -527,6 +558,7 @@ class PrismaClient:
     async def disconnect(self):
         try:
             await self.db.disconnect()
+            self.connected = False
         except Exception as e:
             asyncio.create_task(
                 self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -9,7 +9,7 @@
 import sys, re, binascii, struct
 import litellm
-import dotenv, json, traceback, threading, base64
+import dotenv, json, traceback, threading, base64, ast
 import subprocess, os
 import litellm, openai
 import itertools
@@ -6621,7 +6621,7 @@ def _is_base64(s):
 
 def get_secret(
     secret_name: str,
-    default_value: Optional[str] = None,
+    default_value: Optional[Union[str, bool]] = None,
 ):
     key_management_system = litellm._key_management_system
     if secret_name.startswith("os.environ/"):
@@ -6672,9 +6672,24 @@ def get_secret(
             secret = client.get_secret(secret_name).secret_value
         except Exception as e:  # check if it's in os.environ
             secret = os.getenv(secret_name)
-        return secret
+        try:
+            secret_value_as_bool = ast.literal_eval(secret)
+            if isinstance(secret_value_as_bool, bool):
+                return secret_value_as_bool
+            else:
+                return secret
+        except:
+            return secret
     else:
-        return os.environ.get(secret_name)
+        secret = os.environ.get(secret_name)
+        try:
+            secret_value_as_bool = ast.literal_eval(secret)
+            if isinstance(secret_value_as_bool, bool):
+                return secret_value_as_bool
+            else:
+                return secret
+        except:
+            return secret
     except Exception as e:
         if default_value is not None:
             return default_value
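The practical effect of the `ast.literal_eval` change: flag-style env vars come back as real booleans, so checks like `litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True` behave as intended. The coercion rule in isolation, as an illustrative sketch of the same logic (the helper name `coerce_bool` is hypothetical):

```python
import ast


def coerce_bool(secret):
    # Mirrors the new get_secret behavior: return a real bool when the raw
    # string is the literal "True"/"False", otherwise pass the value through.
    try:
        value = ast.literal_eval(secret)
        return value if isinstance(value, bool) else secret
    except Exception:
        return secret


assert coerce_bool("True") is True
assert coerce_bool("False") is False
assert coerce_bool("redis") == "redis"  # non-literal strings pass through
```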