From 200e8784f39d92992950ff7f0d3517a997dbab3c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 15 Apr 2024 18:34:40 -0700
Subject: [PATCH] fix(proxy_server.py): fix delete models endpoint

https://github.com/BerriAI/litellm/issues/2951
---
 litellm/proxy/proxy_server.py | 53 +++++++++++++++++++++++++++++++++++
 litellm/router.py             | 23 +++++++++++++++
 tests/test_models.py          | 29 +++++++++++++------
 3 files changed, 96 insertions(+), 9 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0e76848551..7f7c23632b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2442,6 +2442,52 @@ class ProxyConfig:
         router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings
 
+    async def _delete_deployment(self, db_models: list):
+        """
+        (Helper function of add_deployment; combined into one pass to reduce Prisma DB calls)
+
+        - Build a combined list of model ids (db + config)
+        - Compare that list to the router's model ids
+        - Remove any router deployments whose ids are missing from it
+        """
+        global user_config_file_path, llm_router
+        combined_id_list = []
+        if llm_router is None:
+            return
+
+        ## DB MODELS ##
+        for m in db_models:
+            if m.model_info is not None and isinstance(m.model_info, dict):
+                if "id" not in m.model_info:
+                    m.model_info["id"] = m.model_id
+                combined_id_list.append(m.model_info["id"])
+            else:
+                combined_id_list.append(m.model_id)
+        ## CONFIG MODELS ##
+        config = await self.get_config(config_file_path=user_config_file_path)
+        model_list = config.get("model_list", None)
+        if model_list:
+            for model in model_list:
+                ### LOAD FROM os.environ/ ###
+                for k, v in model["litellm_params"].items():
+                    if isinstance(v, str) and v.startswith("os.environ/"):
+                        model["litellm_params"][k] = litellm.get_secret(v)
+                litellm_model_name = model["litellm_params"]["model"]
+                litellm_model_api_base = model["litellm_params"].get("api_base", None)
+
+                model_id = litellm.Router()._generate_model_id(
+                    model_group=model["model_name"],
+                    litellm_params=model["litellm_params"],
+                )
+                combined_id_list.append(model_id)  # ADD CONFIG MODEL TO COMBINED LIST
+
+        router_model_ids = llm_router.get_model_ids()
+
+        # Check for model IDs in llm_router not present in combined_id_list and delete them
+        for model_id in router_model_ids:
+            if model_id not in combined_id_list:
+                llm_router.delete_deployment(id=model_id)
+
     async def add_deployment(
         self,
         prisma_client: PrismaClient,
@@ -2508,7 +2554,10 @@
         else:
             new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
         verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+        ## DELETE MODEL LOGIC
+        await self._delete_deployment(db_models=new_models)
 
+        ## ADD MODEL LOGIC
         for m in new_models:
             _litellm_params = m.litellm_params
             if isinstance(_litellm_params, dict):
@@ -7301,6 +7350,10 @@ async def delete_model(model_info: ModelInfoDelete):
                 detail={"error": f"Model with id={model_info.id} not found in db"},
             )
 
+            ## DELETE FROM ROUTER ##
+            if llm_router is not None:
+                llm_router.delete_deployment(id=model_info.id)
+
             return {"message": f"Model: {result.model_id} deleted successfully"}
         else:
             raise HTTPException(
diff --git a/litellm/router.py b/litellm/router.py
index 6cec40fe93..8c2e21bffb 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2288,6 +2288,29 @@ class Router:
             self.model_names.append(deployment.model_name)
         return
 
+    def delete_deployment(self, id: str) -> Optional[Deployment]:
+        """
+        Parameters:
+        - id: str - the id of the deployment to be deleted
+
+        Returns:
+        - The deleted deployment
+        - OR None (if deleted deployment not found)
+        """
+        deployment_idx = None
+        for idx, m in enumerate(self.model_list):
+            if m["model_info"]["id"] == id:
+                deployment_idx = idx
+
+        try:
+            if deployment_idx is not None:
+                item = self.model_list.pop(deployment_idx)
+                return item
+            else:
+                return None
+        except Exception:
+            return None
+
     def get_deployment(self, model_id: str):
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
diff --git a/tests/test_models.py b/tests/test_models.py
index d1ae4d3dba..4f7c81938d 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -57,15 +57,15 @@ async def test_get_models():
         await get_models(session=session, key=key)
 
 
-async def add_models(session, model_id="123"):
+async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
     url = "http://0.0.0.0:4000/model/new"
     headers = {
         "Authorization": f"Bearer sk-1234",
         "Content-Type": "application/json",
     }
 
     data = {
-        "model_name": "azure-gpt-3.5",
+        "model_name": model_name,
         "litellm_params": {
             "model": "azure/chatgpt-v-2",
             "api_key": "os.environ/AZURE_API_KEY",
@@ -173,19 +173,30 @@ async def delete_model(session, model_id="123"):
 @pytest.mark.asyncio
 async def test_add_and_delete_models():
     """
-    Add model
-    Call new model
+    - Add model
+    - Call new model -> expect to pass
+    - Delete model
+    - Call model -> expect to fail
     """
+    import uuid
+
     async with aiohttp.ClientSession() as session:
         key_gen = await generate_key(session=session)
         key = key_gen["key"]
-        model_id = "12345"
-        response = await add_models(session=session, model_id=model_id)
-        assert response["model_id"] == "12345"
+        model_id = f"12345_{uuid.uuid4()}"
+        model_name = f"{uuid.uuid4()}"
+        response = await add_models(
+            session=session, model_id=model_id, model_name=model_name
+        )
+        assert response["model_id"] == model_id
         await asyncio.sleep(10)
-        await chat_completion(session=session, key=key)
+        await chat_completion(session=session, key=key, model=model_name)
         await delete_model(session=session, model_id=model_id)
-        # raise Exception("it worked!")
+        try:
+            await chat_completion(session=session, key=key, model=model_name)
+            pytest.fail("Expected call to fail.")
+        except Exception:
+            pass
 
 
 async def add_model_for_health_checking(session, model_id="123"):
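
A minimal sketch (not part of the patch) of how the new Router.delete_deployment
is expected to behave once this change lands. The router below is built directly
from a static model_list; the model name, API key, endpoint, and id "12345" are
placeholder values chosen for illustration.

    import litellm

    # One deployment with a caller-controlled id (placeholder credentials).
    router = litellm.Router(
        model_list=[
            {
                "model_name": "azure-gpt-3.5",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": "placeholder-key",  # placeholder, not validated at init
                    "api_base": "https://example-resource.openai.azure.com",  # placeholder
                },
                "model_info": {"id": "12345"},
            }
        ]
    )

    # The first delete pops the deployment out of router.model_list and returns it.
    removed = router.delete_deployment(id="12345")
    assert removed is not None and removed["model_info"]["id"] == "12345"

    # A second delete finds no matching id and returns None.
    assert router.delete_deployment(id="12345") is None

With the proxy change above, the /model/delete endpoint performs this same
router-side removal, so a deleted model stops serving immediately instead of
lingering in the router until the next config reload.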