fix(proxy_server.py): handle router being initialized without a model list

Krrish Dholakia 2024-04-23 10:52:28 -07:00
parent 13b84ca13f
commit f1f08af785
3 changed files with 223 additions and 92 deletions


@@ -2470,14 +2470,20 @@ class ProxyConfig:
                 for k, v in model["litellm_params"].items():
                     if isinstance(v, str) and v.startswith("os.environ/"):
                         model["litellm_params"][k] = litellm.get_secret(v)
-                model_id = llm_router._generate_model_id(
-                    model_group=model["model_name"],
-                    litellm_params=model["litellm_params"],
-                )
+                ## check if they have model-id's ##
+                model_id = model.get("model_info", {}).get("id", None)
+                if model_id is None:
+                    ## else - generate stable id's ##
+                    model_id = llm_router._generate_model_id(
+                        model_group=model["model_name"],
+                        litellm_params=model["litellm_params"],
+                    )
                 combined_id_list.append(model_id)  # ADD CONFIG MODEL TO COMBINED LIST
 
             router_model_ids = llm_router.get_model_ids()
             # Check for model IDs in llm_router not present in combined_id_list and delete them
             deleted_deployments = 0
             for model_id in router_model_ids:
                 if model_id not in combined_id_list:
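The hunk above changes how config models get their ids during sync: an explicit id in the model's `model_info` block now wins, and a generated id is only the fallback. A minimal standalone sketch of that resolution order, using a hash as a stand-in for `llm_router._generate_model_id` (the real router method may derive its stable ids differently):

import hashlib
import json


def resolve_model_id(model: dict) -> str:
    # 1) prefer an explicit id set in the config's model_info block
    model_id = model.get("model_info", {}).get("id", None)
    if model_id is not None:
        return model_id
    # 2) otherwise derive a stable id from model_name + litellm_params
    payload = json.dumps(
        {"model_group": model["model_name"], "litellm_params": model["litellm_params"]},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


# explicit id wins
print(resolve_model_id({"model_name": "gpt-3.5-turbo", "model_info": {"id": "12345"}, "litellm_params": {}}))
# no id in config -> stable generated id
print(resolve_model_id({"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chatgpt-v-2"}}))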
@@ -2538,6 +2544,95 @@ class ProxyConfig:
                     added_models += 1
         return added_models
 
+    async def _update_llm_router(
+        self,
+        new_models: list,
+        proxy_logging_obj: ProxyLogging,
+    ):
+        global llm_router, llm_model_list, master_key, general_settings
+        import base64
+
+        if llm_router is None and master_key is not None:
+            verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+            _model_list: list = []
+            for m in new_models:
+                _litellm_params = m.litellm_params
+                if isinstance(_litellm_params, dict):
+                    # decrypt values
+                    for k, v in _litellm_params.items():
+                        if isinstance(v, str):
+                            # decode base64
+                            decoded_b64 = base64.b64decode(v)
+                            # decrypt value
+                            _litellm_params[k] = decrypt_value(
+                                value=decoded_b64, master_key=master_key  # type: ignore
+                            )
+                    _litellm_params = LiteLLM_Params(**_litellm_params)
+                else:
+                    verbose_proxy_logger.error(
+                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                    )
+                    continue  # skip to next model
+
+                _model_info = self.get_model_info_with_id(model=m)
+                _model_list.append(
+                    Deployment(
+                        model_name=m.model_name,
+                        litellm_params=_litellm_params,
+                        model_info=_model_info,
+                    ).to_json(exclude_none=True)
+                )
+            if len(_model_list) > 0:
+                verbose_proxy_logger.debug(f"_model_list: {_model_list}")
+                llm_router = litellm.Router(model_list=_model_list)
+                verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
+        else:
+            verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+            ## DELETE MODEL LOGIC
+            await self._delete_deployment(db_models=new_models)
+
+            ## ADD MODEL LOGIC
+            self._add_deployment(db_models=new_models)
+
+        if llm_router is not None:
+            llm_model_list = llm_router.get_model_list()
+
+        # check if user set any callbacks in Config Table
+        config_data = await proxy_config.get_config()
+        litellm_settings = config_data.get("litellm_settings", {}) or {}
+        success_callbacks = litellm_settings.get("success_callback", None)
+
+        if success_callbacks is not None and isinstance(success_callbacks, list):
+            for success_callback in success_callbacks:
+                if success_callback not in litellm.success_callback:
+                    litellm.success_callback.append(success_callback)
+
+        # we need to set env variables too
+        environment_variables = config_data.get("environment_variables", {})
+        for k, v in environment_variables.items():
+            try:
+                decoded_b64 = base64.b64decode(v)
+                value = decrypt_value(value=decoded_b64, master_key=master_key)  # type: ignore
+                os.environ[k] = value
+            except Exception as e:
+                verbose_proxy_logger.error(
+                    "Error setting env variable: %s - %s", k, str(e)
+                )
+
+        # general_settings
+        _general_settings = config_data.get("general_settings", {})
+        if "alerting" in _general_settings:
+            general_settings["alerting"] = _general_settings["alerting"]
+            proxy_logging_obj.alerting = general_settings["alerting"]
+        if "alert_types" in _general_settings:
+            general_settings["alert_types"] = _general_settings["alert_types"]
+            proxy_logging_obj.alert_types = general_settings["alert_types"]
+
+        # router settings
+        if llm_router is not None:
+            _router_settings = config_data.get("router_settings", {})
+            llm_router.update_settings(**_router_settings)
+
     async def add_deployment(
         self,
         prisma_client: PrismaClient,
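The new `_update_llm_router` above assumes every string value in a DB model's `litellm_params` was encrypted with the proxy master key and then base64-encoded before being stored. A small round-trip sketch of that convention, assuming `decrypt_value` is exported from `litellm.proxy.utils` alongside `encrypt_value` (the signatures mirror how the proxy code and the test change below use them):

import base64

from litellm.proxy.utils import decrypt_value, encrypt_value

master_key = "sk-1234"  # example key, same placeholder the tests use

# storage side: encrypt with the master key, then base64-encode for the DB
stored = base64.b64encode(encrypt_value("my-azure-api-key", master_key)).decode("utf-8")

# _update_llm_router side: base64-decode, then decrypt back to plaintext
plaintext = decrypt_value(value=base64.b64decode(stored), master_key=master_key)
assert plaintext == "my-azure-api-key"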
@@ -2550,95 +2645,16 @@ class ProxyConfig:
         """
         global llm_router, llm_model_list, master_key, general_settings
-        import base64
 
         try:
             if master_key is None or not isinstance(master_key, str):
                 raise Exception(
                     f"Master key is not initialized or formatted. master_key={master_key}"
                 )
             verbose_proxy_logger.debug(f"llm_router: {llm_router}")
-            if llm_router is None:
-                new_models = (
-                    await prisma_client.db.litellm_proxymodeltable.find_many()
-                )  # get all models in db
-                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
-                _model_list: list = []
-                for m in new_models:
-                    _litellm_params = m.litellm_params
-                    if isinstance(_litellm_params, dict):
-                        # decrypt values
-                        for k, v in _litellm_params.items():
-                            if isinstance(v, str):
-                                # decode base64
-                                decoded_b64 = base64.b64decode(v)
-                                # decrypt value
-                                _litellm_params[k] = decrypt_value(
-                                    value=decoded_b64, master_key=master_key
-                                )
-                        _litellm_params = LiteLLM_Params(**_litellm_params)
-                    else:
-                        verbose_proxy_logger.error(
-                            f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
-                        )
-                        continue  # skip to next model
-
-                    _model_info = self.get_model_info_with_id(model=m)
-                    _model_list.append(
-                        Deployment(
-                            model_name=m.model_name,
-                            litellm_params=_litellm_params,
-                            model_info=_model_info,
-                        ).to_json(exclude_none=True)
-                    )
-                verbose_proxy_logger.debug(f"_model_list: {_model_list}")
-                llm_router = litellm.Router(model_list=_model_list)
-                verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
-            else:
-                new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
-                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
-                ## DELETE MODEL LOGIC
-                await self._delete_deployment(db_models=new_models)
-
-                ## ADD MODEL LOGIC
-                self._add_deployment(db_models=new_models)
-
-            llm_model_list = llm_router.get_model_list()
-
-            # check if user set any callbacks in Config Table
-            config_data = await proxy_config.get_config()
-            litellm_settings = config_data.get("litellm_settings", {}) or {}
-            success_callbacks = litellm_settings.get("success_callback", None)
-
-            if success_callbacks is not None and isinstance(success_callbacks, list):
-                for success_callback in success_callbacks:
-                    if success_callback not in litellm.success_callback:
-                        litellm.success_callback.append(success_callback)
-
-            # we need to set env variables too
-            environment_variables = config_data.get("environment_variables", {})
-            for k, v in environment_variables.items():
-                try:
-                    decoded_b64 = base64.b64decode(v)
-                    value = decrypt_value(value=decoded_b64, master_key=master_key)
-                    os.environ[k] = value
-                except Exception as e:
-                    verbose_proxy_logger.error(
-                        "Error setting env variable: %s - %s", k, str(e)
-                    )
-
-            # general_settings
-            _general_settings = config_data.get("general_settings", {})
-            if "alerting" in _general_settings:
-                general_settings["alerting"] = _general_settings["alerting"]
-                proxy_logging_obj.alerting = general_settings["alerting"]
-            if "alert_types" in _general_settings:
-                general_settings["alert_types"] = _general_settings["alert_types"]
-                proxy_logging_obj.alert_types = general_settings["alert_types"]
-
-            # router settings
-            _router_settings = config_data.get("router_settings", {})
-            llm_router.update_settings(**_router_settings)
+            new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
+
+            await self._update_llm_router(
+                new_models=new_models, proxy_logging_obj=proxy_logging_obj
+            )
         except Exception as e:
             verbose_proxy_logger.error(
                 "{}\nTraceback:{}".format(str(e), traceback.format_exc())


@@ -206,12 +206,16 @@ class Router:
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
 
-        if model_list:
+        if model_list is not None:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
-            self.healthy_deployments: List = self.model_list
+            self.healthy_deployments: List = self.model_list  # type: ignore
             for m in model_list:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+        else:
+            self.model_list: List = (
+                []
+            )  # initialize an empty list - to allow _add_deployment and delete_deployment to work
 
         self.allowed_fails = allowed_fails or litellm.allowed_fails
         self.cooldown_time = cooldown_time or 1
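The `Router.__init__` change above is what the proxy relies on: `model_list` may now be omitted, and the router initializes an empty `self.model_list` instead of leaving the attribute unset, so `_add_deployment` / `delete_deployment` can populate it later. A short sketch of both paths, mirroring the parametrized test below and assuming the AZURE_* environment variables are set as that test expects:

import os

import litellm

# router with no deployments - now starts with an empty model list
empty_router = litellm.Router()
print(len(empty_router.model_list))  # 0

# router seeded with an initial deployment, as before
seeded_router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
        }
    ]
)
print(len(seeded_router.model_list))  # 1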


@@ -15,8 +15,9 @@ sys.path.insert(
 import pytest, litellm
 from pydantic import BaseModel
 from litellm.proxy.proxy_server import ProxyConfig
-from litellm.proxy.utils import encrypt_value
+from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
+from typing import Literal
 
 
 class DBModel(BaseModel):
@@ -163,6 +164,116 @@ async def test_add_existing_deployment():
     assert num_added == 0
 
 
+litellm_params = LiteLLM_Params(
+    model="azure/chatgpt-v-2",
+    api_key=os.getenv("AZURE_API_KEY"),
+    api_base=os.getenv("AZURE_API_BASE"),
+    api_version=os.getenv("AZURE_API_VERSION"),
+)
+deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
+deployment_2 = Deployment(model_name="gpt-3.5-turbo-2", litellm_params=litellm_params)
+
+
+def _create_model_list(flag_value: Literal[0, 1], master_key: str):
+    """
+    0 - empty list
+    1 - list with an element
+    """
+    import base64
+
+    new_litellm_params = LiteLLM_Params(
+        model="azure/chatgpt-v-2-3",
+        api_key=os.getenv("AZURE_API_KEY"),
+        api_base=os.getenv("AZURE_API_BASE"),
+        api_version=os.getenv("AZURE_API_VERSION"),
+    )
+
+    encrypted_litellm_params = new_litellm_params.dict(exclude_none=True)
+
+    for k, v in encrypted_litellm_params.items():
+        if isinstance(v, str):
+            encrypted_value = encrypt_value(v, master_key)
+            encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
+                "utf-8"
+            )
+
+    db_model = DBModel(
+        model_id="12345",
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": "12345"},
+    )
+
+    db_models = [db_model]
+
+    if flag_value == 0:
+        return []
+    elif flag_value == 1:
+        return db_models
+
+
+@pytest.mark.parametrize(
+    "llm_router",
+    [
+        None,
+        litellm.Router(),
+        litellm.Router(
+            model_list=[
+                deployment.to_json(exclude_none=True),
+                deployment_2.to_json(exclude_none=True),
+            ]
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "model_list_flag_value",
+    [0, 1],
+)
 @pytest.mark.asyncio
-async def test_add_and_delete_deployments():
-    pass
+async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
+    """
+    Test add + delete logic in 3 scenarios
+    - when router is none
+    - when router is init but empty
+    - when router is init and not empty
+    """
+    master_key = "sk-1234"
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    setattr(litellm.proxy.proxy_server, "master_key", master_key)
+    pc = ProxyConfig()
+    pl = ProxyLogging(DualCache())
+
+    async def _monkey_patch_get_config(*args, **kwargs):
+        print(f"ENTERS MP GET CONFIG")
+        if llm_router is None:
+            return {}
+        else:
+            print(f"llm_router.model_list: {llm_router.model_list}")
+            return {"model_list": llm_router.model_list}
+
+    pc.get_config = _monkey_patch_get_config
+
+    model_list = _create_model_list(
+        flag_value=model_list_flag_value, master_key=master_key
+    )
+
+    if llm_router is None:
+        prev_llm_router_val = None
+    else:
+        prev_llm_router_val = len(llm_router.model_list)
+
+    await pc._update_llm_router(new_models=model_list, proxy_logging_obj=pl)
+
+    llm_router = getattr(litellm.proxy.proxy_server, "llm_router")
+
+    if model_list_flag_value == 0:
+        if prev_llm_router_val is None:
+            assert prev_llm_router_val == llm_router
+        else:
+            assert prev_llm_router_val == len(llm_router.model_list)
+    else:
+        if prev_llm_router_val is None:
+            assert len(llm_router.model_list) == len(model_list)
+        else:
+            assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
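The new test is parametrized twice — three initial router states (None, an empty Router(), a Router pre-seeded with two deployments) crossed with two DB snapshots (an empty list, a single encrypted model) — giving six cases that cover the scenarios listed in the docstring. Assuming the placeholder path below is replaced with wherever this test file lives in the repo, the cases can be run in isolation with:

python -m pytest <path/to/this_test_file.py> -k test_add_and_delete_deployments -v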