fix(proxy_server.py): persist models added via /model/new to db

allows models to be used across instances

https://github.com/BerriAI/litellm/issues/2319, https://github.com/BerriAI/litellm/issues/2329
Krrish Dholakia 2024-04-03 20:16:41 -07:00
parent 24d9fcb32c
commit f536fb13e6
7 changed files with 435 additions and 86 deletions


@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
    _is_projected_spend_over_limit,
    _get_projected_spend_over_limit,
    update_spend,
    encrypt_value,
    decrypt_value,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
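Reviewer note: the hunk above only imports encrypt_value / decrypt_value from litellm.proxy.utils. As a minimal sketch of the contract those helpers have to satisfy (symmetric encryption keyed off the shared proxy master key), assuming PyNaCl's SecretBox; this is an illustration, not necessarily what litellm.proxy.utils implements:

# Hypothetical stand-ins for litellm.proxy.utils.encrypt_value / decrypt_value.
# Assumes PyNaCl (pip install pynacl); the real helpers may differ.
import hashlib

import nacl.secret


def encrypt_value(value: str, master_key: str) -> bytes:
    # derive a 32-byte symmetric key from the proxy master key
    key = hashlib.sha256(master_key.encode()).digest()
    return nacl.secret.SecretBox(key).encrypt(value.encode("utf-8"))


def decrypt_value(value: bytes, master_key: str) -> str:
    key = hashlib.sha256(master_key.encode()).digest()
    return nacl.secret.SecretBox(key).decrypt(value).decode("utf-8")

Any scheme with this shape works across instances, as long as every proxy shares the same master_key: encryption happens on the instance that handled /model/new, decryption on whichever instance polls the row out of the DB.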
@@ -104,6 +106,8 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.hooks.prompt_injection_detection import (
@@ -2371,6 +2375,64 @@ class ProxyConfig:
        router = litellm.Router(**router_params)  # type:ignore
        return router, model_list, general_settings

    async def add_deployment(
        self,
        prisma_client: PrismaClient,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        - Check db for new models (last 10 most recently updated)
        - Check if model id's in router already
        - If not, add to router
        """
        global llm_router, llm_model_list, master_key

        import base64

        try:
            if llm_router is None:
                raise Exception("No router initialized")
            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
                take=10, order={"updated_at": "desc"}
            )
            for m in new_models:
                _litellm_params = m.litellm_params
                if isinstance(_litellm_params, dict):
                    # decrypt values
                    for k, v in _litellm_params.items():
                        # decode base64
                        decoded_b64 = base64.b64decode(v)
                        # decrypt value
                        _litellm_params[k] = decrypt_value(
                            value=decoded_b64, master_key=master_key
                        )
                    _litellm_params = LiteLLM_Params(**_litellm_params)
                else:
                    raise Exception(
                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
                    )

                if m.model_info is not None and isinstance(m.model_info, dict):
                    if "id" not in m.model_info:
                        m.model_info["id"] = m.model_id
                    _model_info = RouterModelInfo(**m.model_info)
                else:
                    _model_info = RouterModelInfo(id=m.model_id)

                llm_router.add_deployment(
                    deployment=Deployment(
                        model_name=m.model_name,
                        litellm_params=_litellm_params,
                        model_info=_model_info,
                    )
                )

            llm_model_list = llm_router.get_model_list()
        except Exception as e:
            raise e


proxy_config = ProxyConfig()
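A quick round trip (using the hypothetical helpers sketched earlier and a placeholder master key) shows why add_deployment base64-decodes each value before decrypting: /model/new stores the ciphertext as base64 text so it survives the JSON column.

# Round-trip sketch; encrypt_value / decrypt_value are the hypothetical
# stand-ins from the earlier note, and master_key is a placeholder.
import base64

master_key = "sk-1234"

# what /model/new writes into litellm_proxymodeltable.litellm_params
stored = base64.b64encode(
    encrypt_value(value="azure/gpt-4", master_key=master_key)
).decode("utf-8")

# what add_deployment does on every other proxy instance
plain = decrypt_value(value=base64.b64decode(stored), master_key=master_key)
assert plain == "azure/gpt-4"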
@@ -2943,7 +3005,7 @@ async def startup_event():
    if prisma_client is not None:
        create_view_response = await prisma_client.check_view_exists()

    ### START BATCH WRITING DB ###
    ### START BATCH WRITING DB + CHECKING NEW MODELS ###
    if prisma_client is not None:
        scheduler = AsyncIOScheduler()
        interval = random.randint(
@@ -2966,6 +3028,15 @@ async def startup_event():
            seconds=batch_writing_interval,
            args=[prisma_client, db_writer_client, proxy_logging_obj],
        )

        ### ADD NEW MODELS ###
        if general_settings.get("store_model_in_db", False) == True:
            scheduler.add_job(
                proxy_config.add_deployment,
                "interval",
                seconds=30,
                args=[prisma_client, proxy_logging_obj],
            )

        scheduler.start()
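The 30-second poll is a stock APScheduler interval job. A self-contained sketch of the same pattern, with poll_db standing in for proxy_config.add_deployment:

# Polling pattern from the hunk above, isolated. Assumes APScheduler
# (pip install apscheduler); poll_db is a stand-in job.
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def poll_db():
    print("checking db for newly added models...")


async def main():
    scheduler = AsyncIOScheduler()
    scheduler.add_job(poll_db, "interval", seconds=30)
    scheduler.start()
    await asyncio.sleep(90)  # keep the loop alive so the job fires


asyncio.run(main())

Note the batch-writing job above it uses a randomized interval, presumably to jitter writes across instances, while the model poll uses a fixed 30 seconds.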
@@ -3314,8 +3385,6 @@ async def chat_completion(
            )
        )
        start_time = time.time()
        ### ROUTE THE REQUEST ###
        # Do not change this - it should be a constant time fetch - ALWAYS
        router_model_names = llm_router.model_names if llm_router is not None else []
@@ -3534,8 +3603,6 @@ async def embeddings(
            user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
        )
        start_time = time.time()
        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
@@ -6691,30 +6758,47 @@ async def info_budget(data: BudgetRequest):
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
async def add_new_model(
    model_params: ModelParams,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
    try:
        # Load existing config
        config = await proxy_config.get_config()
        import base64

        verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
        global prisma_client
        verbose_proxy_logger.debug("Loaded config: %s", config)
        # Add the new model to the config
        model_info = model_params.model_info.json()
        model_info = {k: v for k, v in model_info.items() if v is not None}
        config["model_list"].append(
            {
                "model_name": model_params.model_name,
                "litellm_params": model_params.litellm_params,
                "model_info": model_info,
            }
        )
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )
        verbose_proxy_logger.debug("updated model list: %s", config["model_list"])
        # Save new config
        await proxy_config.save_config(new_config=config)
        # update DB
        if general_settings.get("store_model_in_db", False) == True:
            """
            - store model_list in db
            - store keys separately
            """
            # encrypt litellm params #
            for k, v in model_params.litellm_params.items():
                encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
                model_params.litellm_params[k] = base64.b64encode(
                    encrypted_value
                ).decode("utf-8")
            await prisma_client.db.litellm_proxymodeltable.create(
                data={
                    "model_name": model_params.model_name,
                    "litellm_params": json.dumps(model_params.litellm_params),  # type: ignore
                    "model_info": model_params.model_info.model_dump_json(  # type: ignore
                        exclude_none=True
                    ),
                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                }
            )
        return {"message": "Model added successfully"}
    except Exception as e:
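With store_model_in_db: true set under general_settings in the proxy config, the endpoint can be exercised with a plain authenticated POST. A hedged example; the base URL, key, and model values are all placeholders:

# Hypothetical client call to /model/new; URL, key, and values are
# placeholders. Requires requests (pip install requests).
import requests

resp = requests.post(
    "http://0.0.0.0:4000/model/new",
    headers={"Authorization": "Bearer sk-1234"},  # admin virtual key
    json={
        "model_name": "gpt-4-team",  # public alias the router will serve
        "litellm_params": {  # encrypted per-key before the DB write above
            "model": "azure/gpt-4",
            "api_key": "my-azure-api-key",
            "api_base": "https://my-endpoint.openai.azure.com",
        },
        "model_info": {},
    },
)
print(resp.json())  # {"message": "Model added successfully"}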
@@ -6884,14 +6968,16 @@ async def model_info_v1(
):
    global llm_model_list, general_settings, user_config_file_path, proxy_config

    # Load existing config
    config = await proxy_config.get_config()
    if llm_model_list is None:
        raise HTTPException(
            status_code=500, detail={"error": "LLM Model List not loaded in"}
        )

    if len(user_api_key_dict.models) > 0:
        model_names = user_api_key_dict.models
        all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
        all_models = [m for m in llm_model_list if m["model_name"] in model_names]
    else:
        all_models = config["model_list"]
        all_models = llm_model_list
    for model in all_models:
        # provided model_info in config.yaml
        model_info = model.get("model_info", {})
@@ -6956,6 +7042,7 @@ async def delete_model(model_info: ModelInfoDelete):
        # Check if the model with the specified model_id exists
        model_to_delete = None
        for model in config["model_list"]:
            if model.get("model_info", {}).get("id", None) == model_info.id:
                model_to_delete = model