fix(proxy_server.py): allow for no models in model_list - all models passed in via /model/new

Krrish Dholakia 2024-04-08 16:17:34 -07:00
parent 0d925a6c55
commit 8f1872eaf3

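For context, this change turns `store_model_in_db` into a module-level feature flag: with it enabled, the proxy can boot with no `model_list` at all and source every deployment from the database via `/model/new`. A minimal sketch of the intended setup, written as the Python dict equivalent of the proxy's YAML config (key names are taken from this diff; the `database_url` placement is an assumption):

    # Hypothetical config: no statically-declared models.
    config = {
        "model_list": [],  # previously required to be non-empty; now optional
        "general_settings": {
            "master_key": "sk-1234",        # placeholder; also encrypts stored litellm_params
            "store_model_in_db": True,      # feature flag gating /model/new persistence
            "database_url": "os.environ/DATABASE_URL",  # assumption: the DB that holds the models
        },
    }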

@@ -305,6 +305,7 @@ litellm_master_key_hash = None
 disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
+store_model_in_db: bool = False
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
 proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
@@ -1922,7 +1923,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db

         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -2262,6 +2263,10 @@ class ProxyConfig:
             if master_key is not None and isinstance(master_key, str):
                 litellm_master_key_hash = master_key
+            ### STORE MODEL IN DB ### feature flag for `/model/new`
+            store_model_in_db = general_settings.get("store_model_in_db", False)
+            if store_model_in_db is None:
+                store_model_in_db = False
             ### CUSTOM API KEY AUTH ###
             ## pass filepath
             custom_auth = general_settings.get("custom_auth", None)
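The `is None` guard above is not redundant: `dict.get` only falls back to its default when the key is missing, so a config entry like `store_model_in_db:` left blank in YAML parses to `None` and would otherwise leak through. A quick illustration:

    general_settings = {"store_model_in_db": None}  # YAML "store_model_in_db:" with no value

    store_model_in_db = general_settings.get("store_model_in_db", False)
    print(store_model_in_db)  # None -- the default is NOT used, because the key exists

    if store_model_in_db is None:
        store_model_in_db = False  # normalize to a real bool, as the hunk does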
@@ -2327,13 +2332,6 @@ class ProxyConfig:
                 "background_health_checks", False
             )
             health_check_interval = general_settings.get("health_check_interval", 300)
-            ## check env ##
-            _store_model_in_db = litellm.get_secret(
-                "STORE_MODEL_IN_DB", None
-            )  # feature flag for `/model/new`
-            verbose_proxy_logger.info(f"'STORE_MODEL_IN_DB'={_store_model_in_db}")
-            if _store_model_in_db is not None and _store_model_in_db == True:
-                general_settings["store_model_in_db"] = True
         router_params: dict = {
             "cache_responses": litellm.cache
             != None,  # cache if user passed in cache values
@@ -2390,55 +2388,95 @@ class ProxyConfig:
         import base64

         try:
-            if llm_router is None:
-                raise Exception("No router initialized")
             if master_key is None or not isinstance(master_key, str):
                 raise Exception(
                     f"Master key is not initialized or formatted. master_key={master_key}"
                 )
-            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
-                take=10, order={"updated_at": "desc"}
-            )
-            for m in new_models:
-                _litellm_params = m.litellm_params
-                if isinstance(_litellm_params, dict):
-                    # decrypt values
-                    for k, v in _litellm_params.items():
-                        if isinstance(v, str):
-                            # decode base64
-                            decoded_b64 = base64.b64decode(v)
-                            # decrypt value
-                            _litellm_params[k] = decrypt_value(
-                                value=decoded_b64, master_key=master_key
-                            )
-                    _litellm_params = LiteLLM_Params(**_litellm_params)
-                else:
-                    verbose_proxy_logger.error(
-                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
-                    )
-                    continue  # skip to next model
-                if m.model_info is not None and isinstance(m.model_info, dict):
-                    if "id" not in m.model_info:
-                        m.model_info["id"] = m.model_id
-                    _model_info = RouterModelInfo(**m.model_info)
-                else:
-                    _model_info = RouterModelInfo(id=m.model_id)
-                llm_router.add_deployment(
-                    deployment=Deployment(
-                        model_name=m.model_name,
-                        litellm_params=_litellm_params,
-                        model_info=_model_info,
-                    )
-                )
+            if llm_router is None:
+                new_models = (
+                    await prisma_client.db.litellm_proxymodeltable.find_many()
+                )  # get all models in db
+                _model_list: list = []
+                for m in new_models:
+                    _litellm_params = m.litellm_params
+                    if isinstance(_litellm_params, dict):
+                        # decrypt values
+                        for k, v in _litellm_params.items():
+                            if isinstance(v, str):
+                                # decode base64
+                                decoded_b64 = base64.b64decode(v)
+                                # decrypt value
+                                _litellm_params[k] = decrypt_value(
+                                    value=decoded_b64, master_key=master_key
+                                )
+                        _litellm_params = LiteLLM_Params(**_litellm_params)
+                    else:
+                        verbose_proxy_logger.error(
+                            f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                        )
+                        continue  # skip to next model
+                    if m.model_info is not None and isinstance(m.model_info, dict):
+                        if "id" not in m.model_info:
+                            m.model_info["id"] = m.model_id
+                        _model_info = RouterModelInfo(**m.model_info)
+                    else:
+                        _model_info = RouterModelInfo(id=m.model_id)
+                    _model_list.append(
+                        Deployment(
+                            model_name=m.model_name,
+                            litellm_params=_litellm_params,
+                            model_info=_model_info,
+                        ).to_json(exclude_none=True)
+                    )
+                llm_router = litellm.Router(model_list=_model_list)
+            else:
+                new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
+                    take=10, order={"updated_at": "desc"}
+                )
+                for m in new_models:
+                    _litellm_params = m.litellm_params
+                    if isinstance(_litellm_params, dict):
+                        # decrypt values
+                        for k, v in _litellm_params.items():
+                            if isinstance(v, str):
+                                # decode base64
+                                decoded_b64 = base64.b64decode(v)
+                                # decrypt value
+                                _litellm_params[k] = decrypt_value(
+                                    value=decoded_b64, master_key=master_key
+                                )
+                        _litellm_params = LiteLLM_Params(**_litellm_params)
+                    else:
+                        verbose_proxy_logger.error(
+                            f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                        )
+                        continue  # skip to next model
+                    if m.model_info is not None and isinstance(m.model_info, dict):
+                        if "id" not in m.model_info:
+                            m.model_info["id"] = m.model_id
+                        _model_info = RouterModelInfo(**m.model_info)
+                    else:
+                        _model_info = RouterModelInfo(id=m.model_id)
+                    llm_router.add_deployment(
+                        deployment=Deployment(
+                            model_name=m.model_name,
+                            litellm_params=_litellm_params,
+                            model_info=_model_info,
+                        )
+                    )
             llm_model_list = llm_router.get_model_list()
         except Exception as e:
-            verbose_proxy_logger.error("{}".format(str(e)))
+            verbose_proxy_logger.error(
+                "{}\nTraceback:{}".format(str(e), traceback.format_exc())
+            )
proxy_config = ProxyConfig()
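Both branches of the hunk above share the same read path for stored credentials: each string in `litellm_params` was persisted as base64-encoded ciphertext, so loading reverses both layers before `LiteLLM_Params` validation. A standalone sketch of that round-trip, with a toy XOR cipher standing in for the proxy's real `encrypt_value`/`decrypt_value` helpers (the toy cipher is an assumption, for illustration only):

    import base64

    def toy_encrypt(value: str, master_key: str) -> bytes:
        # stand-in for the proxy's encrypt_value(); XOR keystream, illustrative only
        key = master_key.encode()
        return bytes(b ^ key[i % len(key)] for i, b in enumerate(value.encode()))

    def toy_decrypt(value: bytes, master_key: str) -> str:
        # stand-in for decrypt_value(value=..., master_key=...)
        key = master_key.encode()
        return bytes(b ^ key[i % len(key)] for i, b in enumerate(value)).decode()

    master_key = "sk-1234"

    # write path (what /model/new does before inserting the row):
    _litellm_params = {"model": "gpt-3.5-turbo", "api_key": "sk-my-real-key"}
    stored = {
        k: base64.b64encode(toy_encrypt(v, master_key)).decode()
        for k, v in _litellm_params.items()
    }

    # read path, mirroring the loop in add_deployment above:
    loaded = {}
    for k, v in stored.items():
        decoded_b64 = base64.b64decode(v)                 # undo the base64 layer
        loaded[k] = toy_decrypt(decoded_b64, master_key)  # undo the encryption

    assert loaded == _litellm_params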
@@ -2894,7 +2932,7 @@ def on_backoff(details):
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json

     ### LOAD MASTER KEY ###
@@ -3043,7 +3081,8 @@ async def startup_event():
         )

     ### ADD NEW MODELS ###
-    if general_settings.get("store_model_in_db", False) == True:
+    store_model_in_db = litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
+    if store_model_in_db == True:
         scheduler.add_job(
             proxy_config.add_deployment,
             "interval",
@@ -6854,7 +6893,7 @@ async def add_new_model(
     model_params: Deployment,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
-    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db
     try:
         import base64
@@ -6869,7 +6908,7 @@ async def add_new_model(
             )

         # update DB
-        if general_settings.get("store_model_in_db", False) == True:
+        if store_model_in_db == True:
             """
             - store model_list in db
             - store keys separately
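With the flag on, `add_new_model` persists the deployment instead of rewriting the config file, and the polling job above picks it up on every instance. A hedged usage sketch of the endpoint (URL, port, and keys are placeholders; the payload shape follows the `Deployment` fields visible in this diff):

    import requests

    resp = requests.post(
        "http://0.0.0.0:4000/model/new",              # placeholder proxy address
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
        json={
            "model_name": "my-gpt-4",                 # alias clients will request
            "litellm_params": {                       # encrypted before hitting the DB
                "model": "azure/gpt-4",
                "api_key": "os.environ/AZURE_API_KEY",
                "api_base": "https://example-endpoint.openai.azure.com",
            },
        },
    )
    resp.raise_for_status()
    print(resp.json())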
@@ -7156,7 +7195,7 @@ async def delete_model(model_info: ModelInfoDelete):
             )

         # update DB
-        if general_settings.get("store_model_in_db", False) == True:
+        if store_model_in_db == True:
             """
             - store model_list in db
             - store keys separately