mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(proxy_server.py): fix delete models endpoint
https://github.com/BerriAI/litellm/issues/2951
This commit is contained in:
parent
92a92eee3c
commit
200e8784f3
3 changed files with 96 additions and 9 deletions
|
@ -2442,6 +2442,52 @@ class ProxyConfig:
|
||||||
router = litellm.Router(**router_params, semaphore=semaphore) # type:ignore
|
router = litellm.Router(**router_params, semaphore=semaphore) # type:ignore
|
||||||
return router, model_list, general_settings
|
return router, model_list, general_settings
|
||||||
|
|
||||||
|
async def _delete_deployment(self, db_models: list):
    """
    (Helper function of add deployment) -> combined to reduce prisma db calls

    - Create all up list of model id's (db + config)
    - Compare all up list to router model id's
    - Remove any that are missing

    Parameters:
    - db_models: list - model rows loaded from the db; each row is read
      through its `model_id` and `model_info` attributes

    Returns:
    - None (stale deployments are removed from the global `llm_router`)
    """
    global user_config_file_path, llm_router
    if llm_router is None:
        return

    combined_id_list = []

    ## DB MODELS ##
    for m in db_models:
        if m.model_info is not None and isinstance(m.model_info, dict):
            if "id" not in m.model_info:
                m.model_info["id"] = m.model_id
            # Store the id itself, not the whole model_info dict: the router
            # ids compared below would never equal a dict, so every DB-backed
            # deployment would be treated as missing and deleted.
            combined_id_list.append(m.model_info["id"])
        else:
            combined_id_list.append(m.model_id)

    ## CONFIG MODELS ##
    config = await self.get_config(config_file_path=user_config_file_path)
    model_list = config.get("model_list", None)
    if model_list:
        # Build the id-generating Router once instead of once per config model.
        id_generator = litellm.Router()
        for model in model_list:
            ### LOAD FROM os.environ/ ###
            for k, v in model["litellm_params"].items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    model["litellm_params"][k] = litellm.get_secret(v)

            model_id = id_generator._generate_model_id(
                model_group=model["model_name"],
                litellm_params=model["litellm_params"],
            )
            combined_id_list.append(model_id)  # ADD CONFIG MODEL TO COMBINED LIST

    # Snapshot the ids first: delete_deployment mutates the router's model
    # list, so we must not iterate a view of it while removing entries.
    router_model_ids = list(llm_router.get_model_ids())

    # Check for model IDs in llm_router not present in combined_id_list and delete them
    for model_id in router_model_ids:
        if model_id not in combined_id_list:
            llm_router.delete_deployment(id=model_id)
|
||||||
|
|
||||||
async def add_deployment(
|
async def add_deployment(
|
||||||
self,
|
self,
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
|
@ -2508,7 +2554,10 @@ class ProxyConfig:
|
||||||
else:
|
else:
|
||||||
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||||
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
|
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
|
||||||
|
## DELETE MODEL LOGIC
|
||||||
|
await self._delete_deployment(db_models=new_models)
|
||||||
|
|
||||||
|
## ADD MODEL LOGIC
|
||||||
for m in new_models:
|
for m in new_models:
|
||||||
_litellm_params = m.litellm_params
|
_litellm_params = m.litellm_params
|
||||||
if isinstance(_litellm_params, dict):
|
if isinstance(_litellm_params, dict):
|
||||||
|
@ -7301,6 +7350,10 @@ async def delete_model(model_info: ModelInfoDelete):
|
||||||
detail={"error": f"Model with id={model_info.id} not found in db"},
|
detail={"error": f"Model with id={model_info.id} not found in db"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## DELETE FROM ROUTER ##
|
||||||
|
if llm_router is not None:
|
||||||
|
llm_router.delete_deployment(id=model_info.id)
|
||||||
|
|
||||||
return {"message": f"Model: {result.model_id} deleted successfully"}
|
return {"message": f"Model: {result.model_id} deleted successfully"}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
@ -2288,6 +2288,29 @@ class Router:
|
||||||
self.model_names.append(deployment.model_name)
|
self.model_names.append(deployment.model_name)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def delete_deployment(self, id: str) -> "Optional[Deployment]":
    """
    Remove a deployment from `self.model_list` by its model id.

    Parameters:
    - id: str - the id of the deployment to be deleted

    Returns:
    - The deleted deployment
    - OR None (if deleted deployment not found)
    """
    deployment_idx = None
    for idx, m in enumerate(self.model_list):
        # Same guard as get_deployment: entries without model_info / id are
        # skipped instead of raising KeyError and aborting the delete.
        if "model_info" not in m or "id" not in m["model_info"]:
            continue
        if m["model_info"]["id"] == id:
            # No early exit: keeps the existing behavior of removing the last
            # matching entry should the same id appear more than once.
            deployment_idx = idx

    if deployment_idx is None:
        return None
    # The index was just taken from this list, so pop cannot fail here and
    # needs no blanket exception handler.
    return self.model_list.pop(deployment_idx)
|
||||||
|
|
||||||
def get_deployment(self, model_id: str):
|
def get_deployment(self, model_id: str):
|
||||||
for model in self.model_list:
|
for model in self.model_list:
|
||||||
if "model_info" in model and "id" in model["model_info"]:
|
if "model_info" in model and "id" in model["model_info"]:
|
||||||
|
|
|
@ -57,7 +57,7 @@ async def test_get_models():
|
||||||
await get_models(session=session, key=key)
|
await get_models(session=session, key=key)
|
||||||
|
|
||||||
|
|
||||||
async def add_models(session, model_id="123"):
|
async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
|
||||||
url = "http://0.0.0.0:4000/model/new"
|
url = "http://0.0.0.0:4000/model/new"
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer sk-1234",
|
"Authorization": f"Bearer sk-1234",
|
||||||
|
@ -65,7 +65,7 @@ async def add_models(session, model_id="123"):
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model_name": "azure-gpt-3.5",
|
"model_name": model_name,
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "azure/chatgpt-v-2",
|
"model": "azure/chatgpt-v-2",
|
||||||
"api_key": "os.environ/AZURE_API_KEY",
|
"api_key": "os.environ/AZURE_API_KEY",
|
||||||
|
@ -173,19 +173,30 @@ async def delete_model(session, model_id="123"):
|
||||||
@pytest.mark.asyncio
async def test_add_and_delete_models():
    """
    - Add model
    - Call new model -> expect to pass
    - Delete model
    - Call model -> expect to fail
    """
    import uuid

    async with aiohttp.ClientSession() as session:
        key_gen = await generate_key(session=session)
        key = key_gen["key"]
        # Unique id / name per run so repeated runs against the same proxy
        # do not collide with models left over from earlier runs.
        model_id = f"12345_{uuid.uuid4()}"
        model_name = f"{uuid.uuid4()}"
        response = await add_models(
            session=session, model_id=model_id, model_name=model_name
        )
        assert response["model_id"] == model_id
        await asyncio.sleep(10)
        await chat_completion(session=session, key=key, model=model_name)
        await delete_model(session=session, model_id=model_id)
        try:
            await chat_completion(session=session, key=key, model=model_name)
        except Exception:
            # Expected: the deleted model must no longer be callable.
            pass
        else:
            # Kept outside the try: pytest.fail raises a BaseException-derived
            # outcome that a bare `except:` would swallow, which would make
            # this check impossible to fail.
            pytest.fail("Expected call to fail.")
|
||||||
|
|
||||||
|
|
||||||
async def add_model_for_health_checking(session, model_id="123"):
|
async def add_model_for_health_checking(session, model_id="123"):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue