Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
fix(proxy_server.py): handle router being initialized without a model list
This commit is contained in:
parent 13b84ca13f
commit f1f08af785
3 changed files with 223 additions and 92 deletions
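For orientation, a minimal sketch of the failure mode the commit title describes, assuming only that a post-fix litellm build is importable: before the Router change further down, litellm.Router() constructed without a model list never initialized self.model_list, so the proxy's DB-driven deployment sync had nothing to add to or delete from. After this commit an empty list is created instead.

# Minimal sketch, assuming litellm is installed; illustrates post-fix behavior only.
import litellm

router = litellm.Router()       # intentionally constructed without a model_list
print(router.model_list)        # expected after this commit: []
print(router.get_model_ids())   # expected after this commit: [] (method shown in the diff below)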
@@ -2470,6 +2470,11 @@ class ProxyConfig:
                 for k, v in model["litellm_params"].items():
                     if isinstance(v, str) and v.startswith("os.environ/"):
                         model["litellm_params"][k] = litellm.get_secret(v)
+
+                ## check if they have model-id's ##
+                model_id = model.get("model_info", {}).get("id", None)
+                if model_id is None:
+                    ## else - generate stable id's ##
                     model_id = llm_router._generate_model_id(
                         model_group=model["model_name"],
                         litellm_params=model["litellm_params"],
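The hunk above reuses an explicit model_info.id from the config entry when one is present and only falls back to llm_router._generate_model_id(...) otherwise. A hedged sketch of why that matters for the delete logic that follows, under the assumption (suggested by the "generate stable id's" comment) that the generated id is deterministic for identical inputs:

# Hedged sketch: the exact id format is an implementation detail. The point is
# that the same (model_name, litellm_params) pair should map to the same id, so
# config-defined deployments can be matched against what is already loaded in
# the router instead of being deleted and re-added on every config read.
import litellm

router = litellm.Router()
params = {
    "model": "azure/chatgpt-v-2",
    "api_key": "sk-placeholder",                     # placeholder values, not real credentials
    "api_base": "https://example.openai.azure.com",
}

id_a = router._generate_model_id(model_group="gpt-3.5-turbo", litellm_params=params)
id_b = router._generate_model_id(model_group="gpt-3.5-turbo", litellm_params=params)
assert id_a == id_b  # assumption: deterministic ids, per the "stable id's" comment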
@@ -2478,6 +2483,7 @@ class ProxyConfig:
 
         router_model_ids = llm_router.get_model_ids()
         # Check for model IDs in llm_router not present in combined_id_list and delete them
+
         deleted_deployments = 0
         for model_id in router_model_ids:
             if model_id not in combined_id_list:
@@ -2538,30 +2544,15 @@ class ProxyConfig:
                     added_models += 1
         return added_models
 
-    async def add_deployment(
+    async def _update_llm_router(
         self,
-        prisma_client: PrismaClient,
+        new_models: list,
         proxy_logging_obj: ProxyLogging,
     ):
-        """
-        - Check db for new models (last 10 most recently updated)
-        - Check if model id's in router already
-        - If not, add to router
-        """
         global llm_router, llm_model_list, master_key, general_settings
 
         import base64
 
-        try:
-            if master_key is None or not isinstance(master_key, str):
-                raise Exception(
-                    f"Master key is not initialized or formatted. master_key={master_key}"
-                )
-            verbose_proxy_logger.debug(f"llm_router: {llm_router}")
-            if llm_router is None:
-                new_models = (
-                    await prisma_client.db.litellm_proxymodeltable.find_many()
-                )  # get all models in db
+        if llm_router is None and master_key is not None:
             verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
 
             _model_list: list = []
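The hunk above turns the old add_deployment body into _update_llm_router and changes its inputs: instead of receiving a PrismaClient and querying the DB itself, it now receives the already-fetched rows. That is what lets the new unit test at the bottom of this commit drive it without a database. A hedged sketch of calling it the same way that test does (the helper name sync_router_from_rows and the stubbed get_config are illustrative, not from the source):

import asyncio
import litellm.proxy.proxy_server
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import ProxyLogging, DualCache

async def sync_router_from_rows(db_rows: list) -> None:
    # mirror the unit test setup: set the proxy globals, stub get_config, call the helper
    setattr(litellm.proxy.proxy_server, "llm_router", None)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    pc = ProxyConfig()

    async def _fake_get_config(*args, **kwargs):
        return {}  # no config-file models; everything comes from db_rows

    pc.get_config = _fake_get_config
    pl = ProxyLogging(DualCache())

    # db_rows would normally be the output of prisma_client.db.litellm_proxymodeltable.find_many()
    await pc._update_llm_router(new_models=db_rows, proxy_logging_obj=pl)

asyncio.run(sync_router_from_rows([]))  # assumption: an empty result set is valid input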
@@ -2575,7 +2566,7 @@ class ProxyConfig:
                         decoded_b64 = base64.b64decode(v)
                         # decrypt value
                         _litellm_params[k] = decrypt_value(
-                            value=decoded_b64, master_key=master_key
+                            value=decoded_b64, master_key=master_key  # type: ignore
                         )
                     _litellm_params = LiteLLM_Params(**_litellm_params)
                 else:
@@ -2592,11 +2583,11 @@ class ProxyConfig:
                             model_info=_model_info,
                         ).to_json(exclude_none=True)
                     )
+            if len(_model_list) > 0:
                 verbose_proxy_logger.debug(f"_model_list: {_model_list}")
                 llm_router = litellm.Router(model_list=_model_list)
                 verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
         else:
-            new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
             verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
             ## DELETE MODEL LOGIC
             await self._delete_deployment(db_models=new_models)
@@ -2604,6 +2595,7 @@ class ProxyConfig:
             ## ADD MODEL LOGIC
             self._add_deployment(db_models=new_models)
 
+        if llm_router is not None:
             llm_model_list = llm_router.get_model_list()
 
         # check if user set any callbacks in Config Table
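Two small guards in the hunks above carry most of the fix at the proxy layer: only construct a fresh Router when the decrypted DB models produced a non-empty _model_list, and only read get_model_list() when a router actually exists. A stand-alone sketch of the same pattern (the module-level variable names here are illustrative):

from typing import List, Optional

import litellm

llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[List] = None

_model_list: list = []  # e.g. nothing was decrypted out of the DB yet
if len(_model_list) > 0:
    llm_router = litellm.Router(model_list=_model_list)

if llm_router is not None:
    llm_model_list = llm_router.get_model_list()

# with no deployments anywhere, llm_router simply stays None instead of raising
print(llm_router, llm_model_list)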
@@ -2620,7 +2612,7 @@ class ProxyConfig:
             for k, v in environment_variables.items():
                 try:
                     decoded_b64 = base64.b64decode(v)
-                    value = decrypt_value(value=decoded_b64, master_key=master_key)
+                    value = decrypt_value(value=decoded_b64, master_key=master_key)  # type: ignore
                     os.environ[k] = value
                 except Exception as e:
                     verbose_proxy_logger.error(
|
@ -2637,8 +2629,32 @@ class ProxyConfig:
|
||||||
proxy_logging_obj.alert_types = general_settings["alert_types"]
|
proxy_logging_obj.alert_types = general_settings["alert_types"]
|
||||||
|
|
||||||
# router settings
|
# router settings
|
||||||
|
if llm_router is not None:
|
||||||
_router_settings = config_data.get("router_settings", {})
|
_router_settings = config_data.get("router_settings", {})
|
||||||
llm_router.update_settings(**_router_settings)
|
llm_router.update_settings(**_router_settings)
|
||||||
|
|
||||||
|
async def add_deployment(
|
||||||
|
self,
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
- Check db for new models (last 10 most recently updated)
|
||||||
|
- Check if model id's in router already
|
||||||
|
- If not, add to router
|
||||||
|
"""
|
||||||
|
global llm_router, llm_model_list, master_key, general_settings
|
||||||
|
|
||||||
|
try:
|
||||||
|
if master_key is None or not isinstance(master_key, str):
|
||||||
|
raise Exception(
|
||||||
|
f"Master key is not initialized or formatted. master_key={master_key}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"llm_router: {llm_router}")
|
||||||
|
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||||
|
await self._update_llm_router(
|
||||||
|
new_models=new_models, proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(
|
||||||
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
|
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
|
||||||
|
|
|
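With that, the public add_deployment shown above becomes a thin wrapper: validate the master key, fetch every row from litellm_proxymodeltable, and delegate to _update_llm_router. A hedged sketch of what a periodic caller might look like; the polling loop itself is illustrative and not taken from this diff:

import asyncio

async def poll_db_for_model_changes(proxy_config, prisma_client, proxy_logging_obj,
                                    interval_s: float = 10.0) -> None:
    # proxy_config is a ProxyConfig instance; the reconcile work (delete stale
    # router deployments, add new ones) now all happens inside add_deployment.
    while True:
        await proxy_config.add_deployment(
            prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
        )
        await asyncio.sleep(interval_s)

The next hunk moves to the Router constructor itself (class Router), the other half of the fix.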
@@ -206,12 +206,16 @@ class Router:
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
 
-        if model_list:
+        if model_list is not None:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
-            self.healthy_deployments: List = self.model_list
+            self.healthy_deployments: List = self.model_list  # type: ignore
             for m in model_list:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+        else:
+            self.model_list: List = (
+                []
+            )  # initialize an empty list - to allow _add_deployment and delete_deployment to work
 
         self.allowed_fails = allowed_fails or litellm.allowed_fails
         self.cooldown_time = cooldown_time or 1
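The constructor change above is subtle: the old `if model_list:` treated an explicitly passed empty list the same as None and never set self.model_list at all, while the new `if model_list is not None:` runs the normal setup for an empty list and reserves the else-branch (empty-list initialization) for a true None. A small sketch of the post-change behavior, assuming set_model_list([]) leaves the list empty:

import litellm

r_from_none = litellm.Router()                # model_list=None -> else branch -> []
r_from_empty = litellm.Router(model_list=[])  # empty list -> still goes through set_model_list()

# either way the attribute now exists, so the proxy's add/delete deployment
# helpers have a list to mutate
assert r_from_none.model_list == []
assert r_from_empty.model_list == []  # assumption: set_model_list([]) yields an empty list

The remaining hunks extend the proxy-config unit tests for this path.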
@@ -15,8 +15,9 @@ sys.path.insert(
 import pytest, litellm
 from pydantic import BaseModel
 from litellm.proxy.proxy_server import ProxyConfig
-from litellm.proxy.utils import encrypt_value
+from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
+from typing import Literal
 
 
 class DBModel(BaseModel):
@@ -163,6 +164,116 @@ async def test_add_existing_deployment():
     assert num_added == 0
 
 
+litellm_params = LiteLLM_Params(
+    model="azure/chatgpt-v-2",
+    api_key=os.getenv("AZURE_API_KEY"),
+    api_base=os.getenv("AZURE_API_BASE"),
+    api_version=os.getenv("AZURE_API_VERSION"),
+)
+
+deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
+deployment_2 = Deployment(model_name="gpt-3.5-turbo-2", litellm_params=litellm_params)
+
+
+def _create_model_list(flag_value: Literal[0, 1], master_key: str):
+    """
+    0 - empty list
+    1 - list with an element
+    """
+    import base64
+
+    new_litellm_params = LiteLLM_Params(
+        model="azure/chatgpt-v-2-3",
+        api_key=os.getenv("AZURE_API_KEY"),
+        api_base=os.getenv("AZURE_API_BASE"),
+        api_version=os.getenv("AZURE_API_VERSION"),
+    )
+
+    encrypted_litellm_params = new_litellm_params.dict(exclude_none=True)
+
+    for k, v in encrypted_litellm_params.items():
+        if isinstance(v, str):
+            encrypted_value = encrypt_value(v, master_key)
+            encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
+                "utf-8"
+            )
+    db_model = DBModel(
+        model_id="12345",
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": "12345"},
+    )
+
+    db_models = [db_model]
+
+    if flag_value == 0:
+        return []
+    elif flag_value == 1:
+        return db_models
+
+
+@pytest.mark.parametrize(
+    "llm_router",
+    [
+        None,
+        litellm.Router(),
+        litellm.Router(
+            model_list=[
+                deployment.to_json(exclude_none=True),
+                deployment_2.to_json(exclude_none=True),
+            ]
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "model_list_flag_value",
+    [0, 1],
+)
 @pytest.mark.asyncio
-async def test_add_and_delete_deployments():
-    pass
+async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
+    """
+    Test add + delete logic in 3 scenarios
+    - when router is none
+    - when router is init but empty
+    - when router is init and not empty
+    """
+    master_key = "sk-1234"
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    setattr(litellm.proxy.proxy_server, "master_key", master_key)
+    pc = ProxyConfig()
+    pl = ProxyLogging(DualCache())
+
+    async def _monkey_patch_get_config(*args, **kwargs):
+        print(f"ENTERS MP GET CONFIG")
+        if llm_router is None:
+            return {}
+        else:
+            print(f"llm_router.model_list: {llm_router.model_list}")
+            return {"model_list": llm_router.model_list}
+
+    pc.get_config = _monkey_patch_get_config
+
+    model_list = _create_model_list(
+        flag_value=model_list_flag_value, master_key=master_key
+    )
+
+    if llm_router is None:
+        prev_llm_router_val = None
+    else:
+        prev_llm_router_val = len(llm_router.model_list)
+
+    await pc._update_llm_router(new_models=model_list, proxy_logging_obj=pl)
+
+    llm_router = getattr(litellm.proxy.proxy_server, "llm_router")
+
+    if model_list_flag_value == 0:
+        if prev_llm_router_val is None:
+            assert prev_llm_router_val == llm_router
+        else:
+            assert prev_llm_router_val == len(llm_router.model_list)
+    else:
+        if prev_llm_router_val is None:
+            assert len(llm_router.model_list) == len(model_list)
+        else:
+            assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
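The two parametrize decorators above expand the rewritten test into six cases: router None, router empty, and router with two deployments, each crossed with an empty or one-row fake DB result. To run just this test locally, something like the following should work; the exact test-file path is not shown in this excerpt, so selection is by name:

# run from the repo root; -k selects the test by name since the file path
# is not visible in this diff view
import pytest

raise SystemExit(pytest.main(["-q", "-k", "test_add_and_delete_deployments"]))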