forked from phoenix/litellm-mirror
fix(router.py): fix pydantic object logic
This commit is contained in:
parent
ef2f6ef6a2
commit
20849cbbfc
2 changed files with 16 additions and 8 deletions
|
@ -2143,12 +2143,17 @@ class Router:
|
|||
import os
|
||||
|
||||
for model in original_model_list:
|
||||
_model_name = model.pop("model_name")
|
||||
_litellm_params = model.pop("litellm_params")
|
||||
_model_info = model.pop("model_info", {})
|
||||
deployment = Deployment(
|
||||
model_name=model["model_name"],
|
||||
litellm_params=model["litellm_params"],
|
||||
model_info=model.get("model_info", {}),
|
||||
**model,
|
||||
model_name=_model_name,
|
||||
litellm_params=_litellm_params,
|
||||
model_info=_model_info,
|
||||
)
|
||||
self._add_deployment(deployment=deployment)
|
||||
|
||||
deployment = self._add_deployment(deployment=deployment)
|
||||
|
||||
model = deployment.to_json(exclude_none=True)
|
||||
|
||||
|
@ -2157,7 +2162,7 @@ class Router:
|
|||
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
|
||||
self.model_names = [m["model_name"] for m in model_list]
|
||||
|
||||
def _add_deployment(self, deployment: Deployment):
|
||||
def _add_deployment(self, deployment: Deployment) -> Deployment:
|
||||
import os
|
||||
|
||||
#### DEPLOYMENT NAMES INIT ########
|
||||
|
@ -2194,7 +2199,7 @@ class Router:
|
|||
# Check if user is trying to use model_name == "*"
|
||||
# this is a catch all model for their specific api key
|
||||
if deployment.model_name == "*":
|
||||
self.default_deployment = deployment
|
||||
self.default_deployment = deployment.to_json(exclude_none=True)
|
||||
|
||||
# Azure GPT-Vision Enhancements, users can pass os.environ/
|
||||
data_sources = deployment.litellm_params.get("dataSources", [])
|
||||
|
@ -2214,6 +2219,8 @@ class Router:
|
|||
# init OpenAI, Azure clients
|
||||
self.set_client(model=deployment.to_json(exclude_none=True))
|
||||
|
||||
return deployment
|
||||
|
||||
def add_deployment(self, deployment: Deployment):
|
||||
# check if deployment already exists
|
||||
|
||||
|
@ -2450,6 +2457,7 @@ class Router:
|
|||
model = litellm.model_alias_map[
|
||||
model
|
||||
] # update the model to the actual value if an alias has been passed in
|
||||
|
||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||
deployment = self.leastbusy_logger.get_available_deployments(
|
||||
model_group=model, healthy_deployments=healthy_deployments
|
||||
|
|
|
@ -181,7 +181,7 @@ def test_weighted_selection_router_tpm_as_router_param():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
test_weighted_selection_router_tpm_as_router_param()
|
||||
# test_weighted_selection_router_tpm_as_router_param()
|
||||
|
||||
|
||||
def test_weighted_selection_router_rpm_as_router_param():
|
||||
|
@ -433,7 +433,7 @@ def test_usage_based_routing():
|
|||
|
||||
selection_counts[response["model"]] += 1
|
||||
|
||||
# print("selection counts", selection_counts)
|
||||
print("selection counts", selection_counts)
|
||||
|
||||
total_requests = sum(selection_counts.values())
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue