fix(router.py): handle id being passed in as int

This commit is contained in:
Krrish Dholakia 2024-04-04 14:22:50 -07:00
parent 9638e244f8
commit 48a5948081
7 changed files with 29 additions and 4 deletions

View file

@ -37,9 +37,11 @@ class ModelInfo(BaseModel):
str
] # Allow id to be optional on input, but it will always be present as a str in the model instance
def __init__(self, id: Optional[str] = None, **params):
def __init__(self, id: Optional[Union[str, int]] = None, **params):
if id is None:
id = str(uuid.uuid4()) # Generate a UUID if id is None or not provided
elif isinstance(id, int):
id = str(id)
super().__init__(id=id, **params)
class Config:

View file

@ -41,6 +41,8 @@ class LeastBusyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# update cache
@ -67,6 +69,8 @@ class LeastBusyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
@ -95,6 +99,8 @@ class LeastBusyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
@ -124,6 +130,8 @@ class LeastBusyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
@ -152,6 +160,8 @@ class LeastBusyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache

View file

@ -57,6 +57,8 @@ class LowestLatencyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
@ -139,6 +141,8 @@ class LowestLatencyLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values

View file

@ -39,6 +39,8 @@ class LowestTPMLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
@ -87,6 +89,8 @@ class LowestTPMLoggingHandler(CustomLogger):
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]

View file

@ -112,7 +112,7 @@ def test_router_get_available_deployments():
deployment = router.get_available_deployment(model=model_group, messages=None)
print(f"deployment: {deployment}")
assert deployment["model_info"]["id"] == 1
assert deployment["model_info"]["id"] == "1"
## run router completion - assert completion event, no change in 'busy'ness once calls are complete

View file

@ -337,7 +337,9 @@ def test_router_get_available_deployments():
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
print(router.get_available_deployment(model="azure-model"))
assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
assert (
router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
)
# test_router_get_available_deployments()

View file

@ -173,6 +173,7 @@ def test_router_get_available_deployments():
num_retries=3,
) # type: ignore
print(f"router id's: {router.get_model_ids()}")
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
@ -214,7 +215,9 @@ def test_router_get_available_deployments():
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
assert (
router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
)
# test_get_available_deployments()