fix(router.py): fix pydantic object logic

This commit is contained in:
Krrish Dholakia 2024-04-03 21:57:19 -07:00
parent ef2f6ef6a2
commit 20849cbbfc
2 changed files with 16 additions and 8 deletions

View file

@ -2143,12 +2143,17 @@ class Router:
import os
for model in original_model_list:
_model_name = model.pop("model_name")
_litellm_params = model.pop("litellm_params")
_model_info = model.pop("model_info", {})
deployment = Deployment(
model_name=model["model_name"],
litellm_params=model["litellm_params"],
model_info=model.get("model_info", {}),
**model,
model_name=_model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
self._add_deployment(deployment=deployment)
deployment = self._add_deployment(deployment=deployment)
model = deployment.to_json(exclude_none=True)
@ -2157,7 +2162,7 @@ class Router:
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
self.model_names = [m["model_name"] for m in model_list]
def _add_deployment(self, deployment: Deployment):
def _add_deployment(self, deployment: Deployment) -> Deployment:
import os
#### DEPLOYMENT NAMES INIT ########
@ -2194,7 +2199,7 @@ class Router:
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
self.default_deployment = deployment
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", [])
@ -2214,6 +2219,8 @@ class Router:
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
return deployment
def add_deployment(self, deployment: Deployment):
# check if deployment already exists
@ -2450,6 +2457,7 @@ class Router:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments

View file

@ -181,7 +181,7 @@ def test_weighted_selection_router_tpm_as_router_param():
pytest.fail(f"Error occurred: {e}")
test_weighted_selection_router_tpm_as_router_param()
# test_weighted_selection_router_tpm_as_router_param()
def test_weighted_selection_router_rpm_as_router_param():
@ -433,7 +433,7 @@ def test_usage_based_routing():
selection_counts[response["model"]] += 1
# print("selection counts", selection_counts)
print("selection counts", selection_counts)
total_requests = sum(selection_counts.values())