fix(lowest_tpm_rpm_routing.py): broaden scope of get deployment logic

This commit is contained in:
Krrish Dholakia 2023-12-30 13:27:50 +05:30
parent a6719caebd
commit b66cf0aa43
3 changed files with 90 additions and 22 deletions

View file

@@ -1622,7 +1622,7 @@ class Router:
and self.lowesttpm_logger is not None
):
min_deployment = self.lowesttpm_logger.get_available_deployments(
model_group=model
model_group=model, healthy_deployments=healthy_deployments
)
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)

View file

@@ -1,7 +1,7 @@
#### What this does ####
# identifies lowest tpm deployment
import dotenv, os, requests
import dotenv, os, requests, random
from typing import Optional
from datetime import datetime
@@ -118,7 +118,7 @@ class LowestTPMLoggingHandler(CustomLogger):
traceback.print_exc()
pass
def get_available_deployments(self, model_group: str):
def get_available_deployments(self, model_group: str, healthy_deployments: list):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
@@ -139,15 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
if tpm_dict is None: # base case
return
for item, item_tpm in tpm_dict.items():
all_deployments = tpm_dict
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = 0
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
for m in self.model_list:
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
break
continue # skip to next one
_deployment_tpm = (
_deployment.get("tpm", None)
or _deployment.get("litellm_params", {}).get("tpm", None)
@@ -163,7 +170,8 @@
)
if item_tpm == 0:
return item
deployment = _deployment
break
elif (
item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm
): # if user passed in tpm / rpm in the model_list
@@ -171,4 +179,6 @@
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
deployment = _deployment
if deployment is None:
deployment = random.choice(healthy_deployments)
return deployment

View file

@@ -1,7 +1,7 @@
#### What this tests ####
# This tests the router's ability to pick deployment with lowest tpm
import sys, os, asyncio, time
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
@@ -120,11 +120,15 @@ def test_get_available_deployments():
)
## CHECK WHAT'S SELECTED ##
print(lowest_tpm_logger.get_available_deployments(model_group=model_group))
print(
lowest_tpm_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
)
assert (
lowest_tpm_logger.get_available_deployments(model_group=model_group)[
"model_info"
]["id"]
lowest_tpm_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)["model_info"]["id"]
== "5678"
)
@@ -157,16 +161,6 @@ def test_router_get_available_deployments():
},
"model_info": {"id": 2},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
},
]
router = Router(
model_list=model_list,
@@ -224,3 +218,67 @@ def test_router_get_available_deployments():
# test_router_get_available_deployments()
@pytest.mark.asyncio
async def test_router_completion_streaming():
# Integration test for usage-based routing: after two completions have
# accumulated TPM/RPM usage, the deployment the lowest-TPM logger predicts
# must be the deployment the router actually picks for a third call.
# NOTE(review): indentation below is flattened by the diff-page rendering;
# the comments follow the same flat layout to keep code lines byte-identical.
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
# Two deployments in the same "azure-model" group; id 1 has a much larger
# rpm limit (1440) than id 2 (6), so usage-based routing has a real choice.
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
num_retries=3,
) # type: ignore
### Make 3 calls, test if 3rd call goes to lowest tpm deployment
## CALL 1+2
tasks = []
response = None
final_response = None
# Fire the first two completions concurrently so both record usage.
for _ in range(2):
tasks.append(router.acompletion(model=model, messages=messages))
response = await asyncio.gather(*tasks)
if response is not None:
## CALL 3
await asyncio.sleep(1)  # let the token update happen
current_minute = datetime.now().strftime("%H-%M")
# Ask the logger which deployment it would select BEFORE making call 3,
# then make the call and compare. Passes router.healthy_deployments per
# the new get_available_deployments signature introduced in this commit.
picked_deployment = router.lowesttpm_logger.get_available_deployments(
model_group=model, healthy_deployments=router.healthy_deployments
)
final_response = await router.acompletion(model=model, messages=messages)
print(f"min deployment id: {picked_deployment}")
print(f"model id: {final_response._hidden_params['model_id']}")
# The routed call must land on the predicted lowest-TPM deployment.
# NOTE(review): assumes _hidden_params carries 'model_id' — set by the
# router per the surrounding test file; confirm against Router internals.
assert (
final_response._hidden_params["model_id"]
== picked_deployment["model_info"]["id"]
)
# asyncio.run(test_router_completion_streaming())