fix(proxy_server.py): ensure id used in delete deployment matches id used in litellm Router
parent 70716b3373
commit 13cd252f3e
4 changed files with 307 additions and 62 deletions
litellm/proxy/proxy_server.py

@@ -2408,27 +2408,44 @@ class ProxyConfig:
         router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings

-    async def _delete_deployment(self, db_models: list):
+    def get_model_info_with_id(self, model) -> RouterModelInfo:
+        """
+        Common logic across add + delete router models
+        Parameters:
+        - deployment
+
+        Return model info w/ id
+        """
+        if model.model_info is not None and isinstance(model.model_info, dict):
+            if "id" not in model.model_info:
+                model.model_info["id"] = model.model_id
+            _model_info = RouterModelInfo(**model.model_info)
+        else:
+            _model_info = RouterModelInfo(id=model.model_id)
+        return _model_info
+
+    async def _delete_deployment(self, db_models: list) -> int:
         """
         (Helper function of add deployment) -> combined to reduce prisma db calls

         - Create all up list of model id's (db + config)
         - Compare all up list to router model id's
         - Remove any that are missing
+
+        Return:
+        - int - returns number of deleted deployments
         """
         global user_config_file_path, llm_router
         combined_id_list = []
         if llm_router is None:
-            return
+            return 0

         ## DB MODELS ##
         for m in db_models:
-            if m.model_info is not None and isinstance(m.model_info, dict):
-                if "id" not in m.model_info:
-                    m.model_info["id"] = m.model_id
-                combined_id_list.append(m.model_id)
-            else:
-                combined_id_list.append(m.model_id)
+            model_info = self.get_model_info_with_id(model=m)
+            if model_info.id is not None:
+                combined_id_list.append(model_info.id)
+
         ## CONFIG MODELS ##
         config = await self.get_config(config_file_path=user_config_file_path)
         model_list = config.get("model_list", None)
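Note: the core of the fix is the id-resolution rule in get_model_info_with_id: prefer model_info["id"] and fall back to the db row's model_id, so the add and delete paths build ids the same way. A minimal, self-contained sketch of that rule (the RouterModelInfo and DBRow classes below are illustrative stand-ins, not litellm's actual types):

# Illustration of the id-resolution rule shared by add + delete:
# prefer model_info["id"], fall back to the row's model_id.
from typing import Optional
from pydantic import BaseModel


class RouterModelInfo(BaseModel):  # stand-in for litellm's RouterModelInfo
    id: Optional[str] = None


class DBRow(BaseModel):  # stand-in for a Prisma db row
    model_id: str
    model_info: Optional[dict] = None


def get_model_info_with_id(model: DBRow) -> RouterModelInfo:
    if model.model_info is not None and isinstance(model.model_info, dict):
        if "id" not in model.model_info:
            model.model_info["id"] = model.model_id
        return RouterModelInfo(**model.model_info)
    return RouterModelInfo(id=model.model_id)


row = DBRow(model_id="db-123", model_info={"id": "router-abc"})
assert get_model_info_with_id(row).id == "router-abc"  # model_info id wins
row2 = DBRow(model_id="db-456")
assert get_model_info_with_id(row2).id == "db-456"  # falls back to model_id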
@@ -2438,21 +2455,73 @@ class ProxyConfig:
                 for k, v in model["litellm_params"].items():
                     if isinstance(v, str) and v.startswith("os.environ/"):
                         model["litellm_params"][k] = litellm.get_secret(v)
-                litellm_model_name = model["litellm_params"]["model"]
-                litellm_model_api_base = model["litellm_params"].get("api_base", None)
-
-                model_id = litellm.Router()._generate_model_id(
+                model_id = llm_router._generate_model_id(
                     model_group=model["model_name"],
                     litellm_params=model["litellm_params"],
                 )
                 combined_id_list.append(model_id)  # ADD CONFIG MODEL TO COMBINED LIST

         router_model_ids = llm_router.get_model_ids()
         # Check for model IDs in llm_router not present in combined_id_list and delete them
+        deleted_deployments = 0
         for model_id in router_model_ids:
             if model_id not in combined_id_list:
-                llm_router.delete_deployment(id=model_id)
+                is_deleted = llm_router.delete_deployment(id=model_id)
+                if is_deleted is not None:
+                    deleted_deployments += 1
+        return deleted_deployments
+
+    def _add_deployment(self, db_models: list) -> int:
+        """
+        Iterate through db models
+
+        for any not in router - add them.
+
+        Return - number of deployments added
+        """
+        import base64
+
+        if master_key is None or not isinstance(master_key, str):
+            raise Exception(
+                f"Master key is not initialized or formatted. master_key={master_key}"
+            )
+
+        if llm_router is None:
+            return 0
+
+        added_models = 0
+        ## ADD MODEL LOGIC
+        for m in db_models:
+            _litellm_params = m.litellm_params
+            if isinstance(_litellm_params, dict):
+                # decrypt values
+                for k, v in _litellm_params.items():
+                    if isinstance(v, str):
+                        # decode base64
+                        decoded_b64 = base64.b64decode(v)
+                        # decrypt value
+                        _litellm_params[k] = decrypt_value(
+                            value=decoded_b64, master_key=master_key
+                        )
+                _litellm_params = LiteLLM_Params(**_litellm_params)
+            else:
+                verbose_proxy_logger.error(
+                    f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                )
+                continue  # skip to next model
+            _model_info = self.get_model_info_with_id(model=m)
+
+            added = llm_router.add_deployment(
+                deployment=Deployment(
+                    model_name=m.model_name,
+                    litellm_params=_litellm_params,
+                    model_info=_model_info,
+                )
+            )
+
+            if added is not None:
+                added_models += 1
+        return added_models

     async def add_deployment(
         self,
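Note: _add_deployment assumes db-stored litellm_params are encrypted with the master key and then base64-encoded. A short sketch of that round trip, using the encrypt_value/decrypt_value helpers that the proxy and the new tests already call, and treating them as symmetric inverses (the cipher itself is an implementation detail not relied on here):

# Round trip assumed by _add_deployment: b64decode, then decrypt with master_key.
import base64

from litellm.proxy.utils import decrypt_value, encrypt_value

master_key = "sk-1234"
plaintext = "my-provider-api-key"  # illustrative value, not from the diff

# write path: encrypt + base64-encode before persisting litellm_params to the db
stored = base64.b64encode(encrypt_value(plaintext, master_key)).decode("utf-8")

# read path: what _add_deployment does for every string param it loads
recovered = decrypt_value(value=base64.b64decode(stored), master_key=master_key)
assert recovered == plaintext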
@@ -2500,13 +2569,7 @@ class ProxyConfig:
                     )
                     continue  # skip to next model

-                if m.model_info is not None and isinstance(m.model_info, dict):
-                    if "id" not in m.model_info:
-                        m.model_info["id"] = m.model_id
-                    _model_info = RouterModelInfo(**m.model_info)
-                else:
-                    _model_info = RouterModelInfo(id=m.model_id)
-
+                _model_info = self.get_model_info_with_id(model=m)
                 _model_list.append(
                     Deployment(
                         model_name=m.model_name,
@@ -2524,39 +2587,7 @@ class ProxyConfig:
             await self._delete_deployment(db_models=new_models)

             ## ADD MODEL LOGIC
-            for m in new_models:
-                _litellm_params = m.litellm_params
-                if isinstance(_litellm_params, dict):
-                    # decrypt values
-                    for k, v in _litellm_params.items():
-                        if isinstance(v, str):
-                            # decode base64
-                            decoded_b64 = base64.b64decode(v)
-                            # decrypt value
-                            _litellm_params[k] = decrypt_value(
-                                value=decoded_b64, master_key=master_key
-                            )
-                    _litellm_params = LiteLLM_Params(**_litellm_params)
-                else:
-                    verbose_proxy_logger.error(
-                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
-                    )
-                    continue  # skip to next model
-
-                if m.model_info is not None and isinstance(m.model_info, dict):
-                    if "id" not in m.model_info:
-                        m.model_info["id"] = m.model_id
-                    _model_info = RouterModelInfo(**m.model_info)
-                else:
-                    _model_info = RouterModelInfo(id=m.model_id)
-
-                llm_router.add_deployment(
-                    deployment=Deployment(
-                        model_name=m.model_name,
-                        litellm_params=_litellm_params,
-                        model_info=_model_info,
-                    )
-                )
-
+            self._add_deployment(db_models=new_models)

             llm_model_list = llm_router.get_model_list()
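Note: with both helpers in place, the periodic sync reduces to a delete pass plus an add pass, each reporting a count. A hypothetical sync_once wrapper (not part of this commit) shows the contract the refactor relies on:

# Hypothetical wrapper over the two helpers this commit introduces/retypes.
from typing import Protocol, Tuple


class DeploymentSyncer(Protocol):
    async def _delete_deployment(self, db_models: list) -> int: ...
    def _add_deployment(self, db_models: list) -> int: ...


async def sync_once(cfg: DeploymentSyncer, db_models: list) -> Tuple[int, int]:
    deleted = await cfg._delete_deployment(db_models=db_models)  # prune stale ids
    added = cfg._add_deployment(db_models=db_models)  # add ids the router lacks
    return deleted, added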
@@ -3218,7 +3249,7 @@ async def startup_event():
         scheduler.add_job(
             proxy_config.add_deployment,
             "interval",
-            seconds=30,
+            seconds=10,
             args=[prisma_client, proxy_logging_obj],
         )

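Note: the tighter 10-second interval is APScheduler's interval trigger. A minimal standalone sketch of the same pattern, with a hypothetical poll coroutine standing in for proxy_config.add_deployment (the proxy is assumed here to use APScheduler's asyncio scheduler):

# Minimal APScheduler interval-trigger example mirroring the change above.
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def poll():  # hypothetical stand-in for proxy_config.add_deployment
    print("checking db for added/deleted deployments")


async def main():
    scheduler = AsyncIOScheduler()
    scheduler.add_job(poll, "interval", seconds=10)  # was 30s before this commit
    scheduler.start()
    await asyncio.sleep(35)  # let the job fire a few times, then exit


if __name__ == "__main__":
    asyncio.run(main())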
litellm/router.py

@@ -2271,11 +2271,19 @@ class Router:

         return deployment

-    def add_deployment(self, deployment: Deployment):
+    def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
+        """
+        Parameters:
+        - deployment: Deployment - the deployment to be added to the Router
+
+        Returns:
+        - The added deployment
+        - OR None (if deployment already exists)
+        """
         # check if deployment already exists

         if deployment.model_info.id in self.get_model_ids():
-            return
+            return None

         # add to model list
         _deployment = deployment.to_json(exclude_none=True)
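Note: returning Optional[Deployment] is what lets the proxy-side _add_deployment count additions and skip duplicates. A hedged usage sketch (assumes an empty litellm.Router works as shown, plus the Deployment/LiteLLM_Params types the new tests import):

# add_deployment now reports whether anything was actually added.
import litellm
from litellm.types.router import Deployment, LiteLLM_Params

router = litellm.Router(model_list=[])
dep = Deployment(
    model_name="gpt-3.5-turbo",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo"),
)

assert router.add_deployment(deployment=dep) is not None  # new id -> added
assert router.add_deployment(deployment=dep) is None  # same id -> skipped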
@@ -2286,7 +2294,7 @@ class Router:

         # add to model names
         self.model_names.append(deployment.model_name)
-        return
+        return deployment

     def delete_deployment(self, id: str) -> Optional[Deployment]:
         """
litellm/tests/test_config.py (new file, 168 additions)
@@ -0,0 +1,168 @@
+# What is this?
+## Unit tests for ProxyConfig class
+
+
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os, io
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the, system path
+import pytest, litellm
+from pydantic import BaseModel
+from litellm.proxy.proxy_server import ProxyConfig
+from litellm.proxy.utils import encrypt_value
+from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
+
+
+class DBModel(BaseModel):
+    model_id: str
+    model_name: str
+    model_info: dict
+    litellm_params: dict
+
+
+@pytest.mark.asyncio
+async def test_delete_deployment():
+    """
+    - Ensure the global llm router is not being reset
+    - Ensure invalid model is deleted
+    - Check if model id != model_info["id"], the model_info["id"] is picked
+    """
+    import base64
+
+    litellm_params = LiteLLM_Params(
+        model="azure/chatgpt-v-2",
+        api_key=os.getenv("AZURE_API_KEY"),
+        api_base=os.getenv("AZURE_API_BASE"),
+        api_version=os.getenv("AZURE_API_VERSION"),
+    )
+    encrypted_litellm_params = litellm_params.dict(exclude_none=True)
+
+    master_key = "sk-1234"
+
+    setattr(litellm.proxy.proxy_server, "master_key", master_key)
+
+    for k, v in encrypted_litellm_params.items():
+        if isinstance(v, str):
+            encrypted_value = encrypt_value(v, master_key)
+            encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
+                "utf-8"
+            )
+
+    deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
+    deployment_2 = Deployment(
+        model_name="gpt-3.5-turbo-2", litellm_params=litellm_params
+    )
+
+    llm_router = litellm.Router(
+        model_list=[
+            deployment.to_json(exclude_none=True),
+            deployment_2.to_json(exclude_none=True),
+        ]
+    )
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    print(f"llm_router: {llm_router}")
+
+    pc = ProxyConfig()
+
+    db_model = DBModel(
+        model_id=deployment.model_info.id,
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": deployment.model_info.id},
+    )
+
+    db_models = [db_model]
+    deleted_deployments = await pc._delete_deployment(db_models=db_models)
+
+    assert deleted_deployments == 1
+    assert len(llm_router.model_list) == 1
+
+    """
+    Scenario 2 - if model id != model_info["id"]
+    """
+
+    llm_router = litellm.Router(
+        model_list=[
+            deployment.to_json(exclude_none=True),
+            deployment_2.to_json(exclude_none=True),
+        ]
+    )
+    print(f"llm_router: {llm_router}")
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    pc = ProxyConfig()
+
+    db_model = DBModel(
+        model_id="12340523",
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": deployment.model_info.id},
+    )
+
+    db_models = [db_model]
+    deleted_deployments = await pc._delete_deployment(db_models=db_models)
+
+    assert deleted_deployments == 1
+    assert len(llm_router.model_list) == 1
+
+
+@pytest.mark.asyncio
+async def test_add_existing_deployment():
+    """
+    - Only add new models
+    - don't re-add existing models
+    """
+    import base64
+
+    litellm_params = LiteLLM_Params(
+        model="gpt-3.5-turbo",
+        api_key=os.getenv("AZURE_API_KEY"),
+        api_base=os.getenv("AZURE_API_BASE"),
+        api_version=os.getenv("AZURE_API_VERSION"),
+    )
+    deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
+    deployment_2 = Deployment(
+        model_name="gpt-3.5-turbo-2", litellm_params=litellm_params
+    )
+
+    llm_router = litellm.Router(
+        model_list=[
+            deployment.to_json(exclude_none=True),
+            deployment_2.to_json(exclude_none=True),
+        ]
+    )
+    print(f"llm_router: {llm_router}")
+    master_key = "sk-1234"
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    setattr(litellm.proxy.proxy_server, "master_key", master_key)
+    pc = ProxyConfig()
+
+    encrypted_litellm_params = litellm_params.dict(exclude_none=True)
+
+    for k, v in encrypted_litellm_params.items():
+        if isinstance(v, str):
+            encrypted_value = encrypt_value(v, master_key)
+            encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
+                "utf-8"
+            )
+    db_model = DBModel(
+        model_id=deployment.model_info.id,
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": deployment.model_info.id},
+    )
+
+    db_models = [db_model]
+    num_added = pc._add_deployment(db_models=db_models)
+
+    assert num_added == 0
+
+
+@pytest.mark.asyncio
+async def test_add_and_delete_deployments():
+    pass
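Note: the tests above inject module globals before exercising the helpers, since _delete_deployment/_add_deployment read llm_router and master_key from litellm.proxy.proxy_server. A hypothetical fixture-style helper (make_proxy_config is not part of this commit) capturing that pattern:

# Sketch of the global-injection setup the new tests rely on.
import litellm
import litellm.proxy.proxy_server as proxy_server
from litellm.proxy.proxy_server import ProxyConfig


def make_proxy_config(model_list: list) -> ProxyConfig:
    # Point the module globals at a fresh Router and a known master key
    # before calling _delete_deployment / _add_deployment.
    setattr(proxy_server, "llm_router", litellm.Router(model_list=model_list))
    setattr(proxy_server, "master_key", "sk-1234")
    return ProxyConfig()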
litellm/types/router.py

@@ -101,12 +101,39 @@ class LiteLLM_Params(BaseModel):
     aws_secret_access_key: Optional[str] = None
     aws_region_name: Optional[str] = None

-    def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
+    def __init__(
+        self,
+        model: str,
+        max_retries: Optional[Union[int, str]] = None,
+        tpm: Optional[int] = None,
+        rpm: Optional[int] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        timeout: Optional[Union[float, str]] = None,  # if str, pass in as os.environ/
+        stream_timeout: Optional[Union[float, str]] = (
+            None  # timeout when making stream=True calls, if str, pass in as os.environ/
+        ),
+        organization: Optional[str] = None,  # for openai orgs
+        ## VERTEX AI ##
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        ## AWS BEDROCK / SAGEMAKER ##
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        **params
+    ):
+        args = locals()
+        args.pop("max_retries", None)
+        args.pop("self", None)
+        args.pop("params", None)
+        args.pop("__class__", None)
         if max_retries is None:
             max_retries = 2
         elif isinstance(max_retries, str):
             max_retries = int(max_retries)  # cast to int
-        super().__init__(max_retries=max_retries, **params)
+        super().__init__(max_retries=max_retries, **args, **params)

     class Config:
         extra = "allow"
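Note: with the parameters now named in the signature, construction reads as in the new tests. A small example (the Azure-style env-var values are assumptions, as is the final assert, which follows from the str-to-int cast above):

# Constructing LiteLLM_Params with the widened, explicit signature.
import os

from litellm.types.router import LiteLLM_Params

params = LiteLLM_Params(
    model="azure/chatgpt-v-2",
    api_key=os.getenv("AZURE_API_KEY"),
    api_base=os.getenv("AZURE_API_BASE"),
    api_version=os.getenv("AZURE_API_VERSION"),
    max_retries="3",  # str is accepted and cast to int by __init__
    rpm=100,
)
assert params.max_retries == 3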
@@ -133,12 +160,23 @@ class Deployment(BaseModel):
     litellm_params: LiteLLM_Params
     model_info: ModelInfo

-    def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
+    def __init__(
+        self,
+        model_name: str,
+        litellm_params: LiteLLM_Params,
+        model_info: Optional[Union[ModelInfo, dict]] = None,
+        **params
+    ):
         if model_info is None:
             model_info = ModelInfo()
         elif isinstance(model_info, dict):
             model_info = ModelInfo(**model_info)
-        super().__init__(model_info=model_info, **params)
+        super().__init__(
+            model_info=model_info,
+            model_name=model_name,
+            litellm_params=litellm_params,
+            **params
+        )

     def to_json(self, **kwargs):
         try:
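Note: the explicit Deployment constructor pairs with to_json(exclude_none=True) when feeding a litellm.Router, exactly as the new tests do. For example (the auto-generated id when model_info is omitted is what the tests rely on):

# Building a Deployment and serializing it for Router(model_list=[...]).
from litellm.types.router import Deployment, LiteLLM_Params

dep = Deployment(
    model_name="gpt-3.5-turbo",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo"),
)
print(dep.model_info.id)  # ModelInfo is auto-populated when model_info is omitted
print(dep.to_json(exclude_none=True))  # dict form accepted by litellm.Router(model_list=[...])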