diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 3d0f9b4d12..069309f8c6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6988,6 +6988,7 @@ async def add_new_model( }, ) + model_response = None # update DB if store_model_in_db == True: """ @@ -7002,17 +7003,20 @@ async def add_new_model( model_params.litellm_params[k] = base64.b64encode( encrypted_value ).decode("utf-8") - await prisma_client.db.litellm_proxymodeltable.create( - data={ - "model_id": model_params.model_info.id, - "model_name": model_params.model_name, - "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True), # type: ignore - "model_info": model_params.model_info.model_dump_json( # type: ignore - exclude_none=True - ), - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } + _data: dict = { + "model_id": model_params.model_info.id, + "model_name": model_params.model_name, + "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True), # type: ignore + "model_info": model_params.model_info.model_dump_json( # type: ignore + exclude_none=True + ), + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + if model_params.model_info.id is not None: + _data["model_id"] = model_params.model_info.id + model_response = await prisma_client.db.litellm_proxymodeltable.create( + data=_data # type: ignore ) await proxy_config.add_deployment( @@ -7026,7 +7030,8 @@ async def add_new_model( "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature." }, ) - return {"message": "Model added successfully"} + + return model_response except Exception as e: traceback.print_exc() @@ -7282,15 +7287,13 @@ async def delete_model(model_info: ModelInfoDelete): result = await prisma_client.db.litellm_proxymodeltable.delete( where={"model_id": model_info.id} ) - + if result is None: raise HTTPException( status_code=400, - detail={ - "error": f"Model with id={model_info.id} not found in db" - }, + detail={"error": f"Model with id={model_info.id} not found in db"}, ) - + return {"message": f"Model: {result.model_id} deleted successfully"} else: raise HTTPException( @@ -7299,7 +7302,6 @@ async def delete_model(model_info: ModelInfoDelete): "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature." }, ) - except Exception as e: if isinstance(e, HTTPException): diff --git a/tests/test_models.py b/tests/test_models.py index 9097d5a69e..d1ae4d3dba 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -78,13 +78,15 @@ async def add_models(session, model_id="123"): async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() - print(f"Add models {response_text}") print() if status != 200: raise Exception(f"Request did not return a 200 status code: {status}") + response_json = await response.json() + return response_json + async def get_model_info(session, key): """ @@ -177,11 +179,13 @@ async def test_add_and_delete_models(): async with aiohttp.ClientSession() as session: key_gen = await generate_key(session=session) key = key_gen["key"] - model_id = "1234" - await add_models(session=session, model_id=model_id) - await asyncio.sleep(60) + model_id = "12345" + response = await add_models(session=session, model_id=model_id) + assert response["model_id"] == "12345" + await asyncio.sleep(10) await chat_completion(session=session, key=key) await delete_model(session=session, model_id=model_id) + # raise Exception("it worked!") async def add_model_for_health_checking(session, model_id="123"):