fix(proxy_server.py): persist models added via /model/new to db

Allows models to be used across instances. See https://github.com/BerriAI/litellm/issues/2319 and https://github.com/BerriAI/litellm/issues/2329.
commit f536fb13e6
parent 24d9fcb32c
7 changed files with 435 additions and 86 deletions
litellm/proxy/proxy_server.py

@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
+    encrypt_value,
+    decrypt_value,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager

@@ -104,6 +106,8 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
+from litellm.router import LiteLLM_Params, Deployment
+from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.hooks.prompt_injection_detection import (

@@ -2371,6 +2375,64 @@ class ProxyConfig:
         router = litellm.Router(**router_params)  # type:ignore
         return router, model_list, general_settings

+    async def add_deployment(
+        self,
+        prisma_client: PrismaClient,
+        proxy_logging_obj: ProxyLogging,
+    ):
+        """
+        - Check db for new models (last 10 most recently updated)
+        - Check if model id's in router already
+        - If not, add to router
+        """
+        global llm_router, llm_model_list, master_key
+
+        import base64
+
+        try:
+            if llm_router is None:
+                raise Exception("No router initialized")
+
+            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
+                take=10, order={"updated_at": "desc"}
+            )
+
+            for m in new_models:
+                _litellm_params = m.litellm_params
+                if isinstance(_litellm_params, dict):
+                    # decrypt values
+                    for k, v in _litellm_params.items():
+                        # decode base64
+                        decoded_b64 = base64.b64decode(v)
+                        # decrypt value
+                        _litellm_params[k] = decrypt_value(
+                            value=decoded_b64, master_key=master_key
+                        )
+                    _litellm_params = LiteLLM_Params(**_litellm_params)
+                else:
+                    raise Exception(
+                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                    )
+
+                if m.model_info is not None and isinstance(m.model_info, dict):
+                    if "id" not in m.model_info:
+                        m.model_info["id"] = m.model_id
+                    _model_info = RouterModelInfo(**m.model_info)
+                else:
+                    _model_info = RouterModelInfo(id=m.model_id)
+
+                llm_router.add_deployment(
+                    deployment=Deployment(
+                        model_name=m.model_name,
+                        litellm_params=_litellm_params,
+                        model_info=_model_info,
+                    )
+                )
+
+            llm_model_list = llm_router.get_model_list()
+        except Exception as e:
+            raise e


 proxy_config = ProxyConfig()
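
The decrypt loop above inverts what /model/new does at write time: each value in litellm_params is encrypted with the proxy master key, then base64-encoded so it can live inside a JSON column. A minimal, self-contained sketch of that roundtrip, using cryptography.fernet as a stand-in cipher for the proxy's encrypt_value/decrypt_value helpers (whose internals are not part of this diff):

    import base64
    import json

    from cryptography.fernet import Fernet  # stand-in cipher; not litellm's actual scheme

    cipher = Fernet(Fernet.generate_key())  # the real helpers derive a key from master_key

    def encrypt_value(value: str) -> bytes:  # plays the role of the imported helper
        return cipher.encrypt(value.encode("utf-8"))

    def decrypt_value(value: bytes) -> str:  # plays the role of the imported helper
        return cipher.decrypt(value).decode("utf-8")

    # write path (/model/new): encrypt each param, then base64-encode for JSON storage
    litellm_params = {"model": "azure/gpt-35-turbo", "api_key": "sk-placeholder"}
    stored = {
        k: base64.b64encode(encrypt_value(v)).decode("utf-8")
        for k, v in litellm_params.items()
    }
    row = json.dumps(stored)  # what lands in litellm_proxymodeltable.litellm_params

    # read path (add_deployment): base64-decode each value, then decrypt
    loaded = {k: decrypt_value(base64.b64decode(v)) for k, v in json.loads(row).items()}
    assert loaded == litellm_params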

@@ -2943,7 +3005,7 @@ async def startup_event():
     if prisma_client is not None:
         create_view_response = await prisma_client.check_view_exists()

-    ### START BATCH WRITING DB ###
+    ### START BATCH WRITING DB + CHECKING NEW MODELS###
     if prisma_client is not None:
         scheduler = AsyncIOScheduler()
         interval = random.randint(

@@ -2966,6 +3028,15 @@ async def startup_event():
             seconds=batch_writing_interval,
             args=[prisma_client, db_writer_client, proxy_logging_obj],
         )
+
+        ### ADD NEW MODELS ###
+        if general_settings.get("store_model_in_db", False) == True:
+            scheduler.add_job(
+                proxy_config.add_deployment,
+                "interval",
+                seconds=30,
+                args=[prisma_client, proxy_logging_obj],
+            )
         scheduler.start()
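
Polling every 30 seconds means a model added on one instance becomes routable on its peers within about half a minute. A standalone sketch of the same AsyncIOScheduler pattern; poll_db is a hypothetical stand-in for proxy_config.add_deployment:

    import asyncio

    from apscheduler.schedulers.asyncio import AsyncIOScheduler

    async def poll_db():
        # stand-in for proxy_config.add_deployment(prisma_client, proxy_logging_obj)
        print("checking db for newly added models...")

    async def main():
        scheduler = AsyncIOScheduler()
        # same shape as the job above: an async callable run on a fixed interval
        scheduler.add_job(poll_db, "interval", seconds=30)
        scheduler.start()  # must be called while the event loop is running
        await asyncio.sleep(120)  # keep the loop alive; the job fires every 30s

    asyncio.run(main())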

@@ -3314,8 +3385,6 @@ async def chat_completion(
             )
         )

-        start_time = time.time()
-
         ### ROUTE THE REQUEST ###
         # Do not change this - it should be a constant time fetch - ALWAYS
         router_model_names = llm_router.model_names if llm_router is not None else []

@@ -3534,8 +3603,6 @@ async def embeddings(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )

-        start_time = time.time()
-
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:

@@ -6691,30 +6758,47 @@ async def info_budget(data: BudgetRequest):
     tags=["model management"],
     dependencies=[Depends(user_api_key_auth)],
 )
-async def add_new_model(model_params: ModelParams):
-    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
+async def add_new_model(
+    model_params: ModelParams,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
     try:
         # Load existing config
         config = await proxy_config.get_config()
+        import base64

-        verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
-        global prisma_client
-
-        verbose_proxy_logger.debug("Loaded config: %s", config)
-        # Add the new model to the config
-        model_info = model_params.model_info.json()
-        model_info = {k: v for k, v in model_info.items() if v is not None}
-        config["model_list"].append(
-            {
-                "model_name": model_params.model_name,
-                "litellm_params": model_params.litellm_params,
-                "model_info": model_info,
-            }
-        )
+        if prisma_client is None:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
+                },
+            )

-        verbose_proxy_logger.debug("updated model list: %s", config["model_list"])
-
-        # Save new config
-        await proxy_config.save_config(new_config=config)
+        # update DB
+        if general_settings.get("store_model_in_db", False) == True:
+            """
+            - store model_list in db
+            - store keys separately
+            """
+            # encrypt litellm params #
+            for k, v in model_params.litellm_params.items():
+                encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
+                model_params.litellm_params[k] = base64.b64encode(
+                    encrypted_value
+                ).decode("utf-8")
+            await prisma_client.db.litellm_proxymodeltable.create(
+                data={
+                    "model_name": model_params.model_name,
+                    "litellm_params": json.dumps(model_params.litellm_params),  # type: ignore
+                    "model_info": model_params.model_info.model_dump_json(  # type: ignore
+                        exclude_none=True
+                    ),
+                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                }
+            )
         return {"message": "Model added successfully"}

     except Exception as e:
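
End to end, the endpoint can now be exercised against any instance and the model survives restarts and appears on peers, provided general_settings has store_model_in_db: true. A hedged example request whose shape follows the ModelParams fields used above; the URL, key, and parameter values are placeholders:

    import requests

    resp = requests.post(
        "http://0.0.0.0:4000/model/new",  # placeholder proxy URL
        headers={"Authorization": "Bearer sk-1234"},  # placeholder admin key
        json={
            "model_name": "my-azure-gpt-4",
            "litellm_params": {
                "model": "azure/gpt-4",
                "api_key": "placeholder-azure-key",  # encrypted before it reaches the DB
                "api_base": "https://example-endpoint.openai.azure.com/",
            },
            "model_info": {},
        },
    )
    print(resp.json())  # {"message": "Model added successfully"}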

@@ -6884,14 +6968,16 @@ async def model_info_v1(
 ):
     global llm_model_list, general_settings, user_config_file_path, proxy_config

-    # Load existing config
-    config = await proxy_config.get_config()
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )

     if len(user_api_key_dict.models) > 0:
         model_names = user_api_key_dict.models
-        all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
+        all_models = [m for m in llm_model_list if m["model_name"] in model_names]
     else:
-        all_models = config["model_list"]
+        all_models = llm_model_list
     for model in all_models:
         # provided model_info in config.yaml
         model_info = model.get("model_info", {})
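
Reading from llm_model_list (the router's live view, refreshed by add_deployment) rather than the file config is what makes DB-added models show up in /model/info. The per-key filtering itself is a plain list comprehension; a tiny illustration with made-up data:

    llm_model_list = [
        {"model_name": "gpt-4", "model_info": {"id": "1"}},
        {"model_name": "db-added-model", "model_info": {"id": "2"}},  # added via the poller
    ]
    allowed = ["db-added-model"]  # user_api_key_dict.models for a scoped key

    all_models = [m for m in llm_model_list if m["model_name"] in allowed]
    print(all_models)  # only models the key is scoped to are returned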

@@ -6956,6 +7042,7 @@ async def delete_model(model_info: ModelInfoDelete):

         # Check if the model with the specified model_id exists
         model_to_delete = None
+
         for model in config["model_list"]:
             if model.get("model_info", {}).get("id", None) == model_info.id:
                 model_to_delete = model