fix(lowest_tpm_rpm.py): handle null case for text/message input

This commit is contained in:
Krrish Dholakia 2024-01-02 12:24:14 +05:30
parent 11f92c0074
commit 2ab31bcaf8
2 changed files with 11 additions and 6 deletions

View file

@ -142,16 +142,21 @@ class LowestTPMLoggingHandler(CustomLogger):
# ----------------------
lowest_tpm = float("inf")
deployment = None
if tpm_dict is None: # base case
return
item = random.choice(healthy_deployments)
return item
all_deployments = tpm_dict
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = 0
input_tokens = token_counter(messages=messages, text=input)
try:
input_tokens = token_counter(messages=messages, text=input)
except:
input_tokens = 0
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None

View file

@ -214,7 +214,6 @@ def test_router_get_available_deployments():
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
print(router.get_available_deployment(model="azure-model"))
assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
@ -309,7 +308,6 @@ async def test_router_completion_streaming():
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
num_retries=3,
) # type: ignore
### Make 3 calls, test if 3rd call goes to lowest tpm deployment
@ -327,7 +325,9 @@ async def test_router_completion_streaming():
await asyncio.sleep(1) # let the token update happen
current_minute = datetime.now().strftime("%H-%M")
picked_deployment = router.lowesttpm_logger.get_available_deployments(
model_group=model, healthy_deployments=router.healthy_deployments
model_group=model,
healthy_deployments=router.healthy_deployments,
messages=messages,
)
final_response = await router.acompletion(model=model, messages=messages)
print(f"min deployment id: {picked_deployment}")