#### MODEL MANAGEMENT ####
@router.post(
    "/model/update",
    description="Edit existing model params",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
    model_params: updateDeployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update the litellm_params of an existing model stored in the DB.

    Merges the caller-supplied ``litellm_params`` over the stored ones:
    params present in the request overwrite (or are added to) the stored
    params; params omitted from the request are left untouched.

    Args:
        model_params: must carry ``model_info.id`` identifying the row to
            update, and the ``litellm_params`` to merge in.
        user_api_key_dict: resolved auth info; ``user_id`` is recorded as
            ``updated_by`` on the row.

    Returns:
        The updated ``litellm_proxymodeltable`` row (when
        ``store_model_in_db`` is enabled).

    Raises:
        ProxyException: if no DB is connected, ``model_info`` /
            ``model_info.id`` / ``litellm_params`` are missing, or the
            model id is not found in the DB.
    """
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )
        # Only persist when the proxy is configured to store models in the DB.
        # NOTE(review): when store_model_in_db is False this endpoint is a
        # silent no-op — confirm whether it should return an error instead.
        if store_model_in_db == True:
            _model_info = getattr(model_params, "model_info", None)
            if _model_info is None:
                raise Exception("model_info not provided")

            _model_id = _model_info.id
            if _model_id is None:
                raise Exception("model_info.id not provided")

            _existing_litellm_params = (
                await prisma_client.db.litellm_proxymodeltable.find_unique(
                    where={"model_id": _model_id}
                )
            )
            if _existing_litellm_params is None:
                raise Exception("model not found")
            _existing_litellm_params_dict = dict(
                _existing_litellm_params.litellm_params
            )

            if model_params.litellm_params is None:
                raise Exception("litellm_params not provided")

            # exclude_none=True -> only the params the caller actually set
            # participate in the merge; everything else keeps its stored value.
            _new_litellm_params_dict = model_params.litellm_params.dict(
                exclude_none=True
            )

            # BUG FIX: the previous loop iterated the *existing* params and
            # only overwrote keys already present, silently dropping any new
            # param supplied by the caller. dict.update() both overwrites
            # existing keys and adds new ones.
            _existing_litellm_params_dict.update(_new_litellm_params_dict)

            _data: dict = {
                "litellm_params": json.dumps(_existing_litellm_params_dict),  # type: ignore
                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            }
            model_response = await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": _model_id},
                data=_data,  # type: ignore
            )
            return model_response
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        # NOTE(review): "Authentication Error" is misleading for generic DB /
        # validation failures; kept for consistency with sibling endpoints.
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
import sys, os
import logging
import time
import uuid

from dotenv import load_dotenv

load_dotenv()

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import pytest
import litellm
from litellm.proxy.proxy_server import add_new_model, update_model
from litellm._logging import verbose_proxy_logger
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.caching import DualCache
from litellm.router import (
    Deployment,
    updateDeployment,
    LiteLLM_Params,
    ModelInfo,
    updateLiteLLMParams,
)
from litellm.proxy._types import UserAPIKeyAuth

verbose_proxy_logger.setLevel(level=logging.DEBUG)

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


@pytest.fixture
def prisma_client():
    """Build a PrismaClient against DATABASE_URL with model-storage enabled."""
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url
    os.environ["STORE_MODEL_IN_DB"] = "true"

    prisma_client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Reset proxy-server globals so each test run starts clean.
    # (Resets custom_db_client — the previous comment wrongly said
    # prisma_client.)
    litellm.proxy.proxy_server.custom_db_client = None
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return prisma_client


def _find_model_by_id(models, model_id):
    """Return the DB row whose model_info['id'] matches model_id, else None."""
    found = None
    for model in models:
        print("current model: ", model)
        if model.model_info["id"] == model_id:
            print("FOUND MODEL: ", model)
            found = model
    return found


async def _add_test_model(model_id):
    """Create a test deployment with fixed litellm_params via add_new_model."""
    await add_new_model(
        model_params=Deployment(
            model_name="test_model",
            litellm_params=LiteLLM_Params(
                model="azure/gpt-3.5-turbo",
                api_key="test_api_key",
                api_base="test_api_base",
                rpm=1000,
                tpm=1000,
            ),
            model_info=ModelInfo(
                id=model_id,
            ),
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role="proxy_admin", api_key="sk-1234", user_id="1234"
        ),
    )


@pytest.mark.asyncio
async def test_add_new_model(prisma_client):
    """A model added via /model/new lands in the litellm_proxymodeltable."""
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)

    await litellm.proxy.proxy_server.prisma_client.connect()

    _new_model_id = f"local-test-{uuid.uuid4().hex}"
    await _add_test_model(_new_model_id)

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    print("_new_models: ", _new_models)

    _new_model_in_db = _find_model_by_id(_new_models, _new_model_id)
    assert _new_model_in_db is not None


@pytest.mark.asyncio
async def test_add_update_model(prisma_client):
    # test that existing litellm_params are not updated
    # only new / updated params get updated
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)

    await litellm.proxy.proxy_server.prisma_client.connect()

    _new_model_id = f"local-test-{uuid.uuid4().hex}"
    await _add_test_model(_new_model_id)

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    print("_new_models: ", _new_models)

    _new_model_in_db = _find_model_by_id(_new_models, _new_model_id)
    assert _new_model_in_db is not None

    _original_model = _new_model_in_db
    _original_litellm_params = _new_model_in_db.litellm_params
    print("_original_litellm_params: ", _original_litellm_params)
    print("now updating the tpm for model")

    # run update to update "tpm"
    await update_model(
        model_params=updateDeployment(
            litellm_params=updateLiteLLMParams(tpm=123456),
            model_info=ModelInfo(
                id=_new_model_id,
            ),
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role="proxy_admin", api_key="sk-1234", user_id="1234"
        ),
    )

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    _new_model_in_db = _find_model_by_id(_new_models, _new_model_id)

    # assert all other litellm params are identical to _original_litellm_params
    for key, value in _original_litellm_params.items():
        if key == "tpm":
            # assert that tpm actually got updated
            assert _new_model_in_db.litellm_params[key] == 123456
        else:
            assert _new_model_in_db.litellm_params[key] == value

    assert _original_model.model_id == _new_model_in_db.model_id
    assert _original_model.model_name == _new_model_in_db.model_name
    assert _original_model.model_info == _new_model_in_db.model_info