fix(lowest_tpm_rpm_routing.py): broaden scope of get deployment logic

This commit is contained in:
Krrish Dholakia 2023-12-30 13:27:50 +05:30
parent a6719caebd
commit b66cf0aa43
3 changed files with 90 additions and 22 deletions

View file

@@ -1622,7 +1622,7 @@ class Router:
and self.lowesttpm_logger is not None and self.lowesttpm_logger is not None
): ):
min_deployment = self.lowesttpm_logger.get_available_deployments( min_deployment = self.lowesttpm_logger.get_available_deployments(
model_group=model model_group=model, healthy_deployments=healthy_deployments
) )
if min_deployment is None: if min_deployment is None:
min_deployment = random.choice(healthy_deployments) min_deployment = random.choice(healthy_deployments)

View file

@@ -1,7 +1,7 @@
#### What this does #### #### What this does ####
# identifies lowest tpm deployment # identifies lowest tpm deployment
import dotenv, os, requests import dotenv, os, requests, random
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
@@ -118,7 +118,7 @@ class LowestTPMLoggingHandler(CustomLogger):
traceback.print_exc() traceback.print_exc()
pass pass
def get_available_deployments(self, model_group: str): def get_available_deployments(self, model_group: str, healthy_deployments: list):
""" """
Returns a deployment with the lowest TPM/RPM usage. Returns a deployment with the lowest TPM/RPM usage.
""" """
@@ -139,15 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
if tpm_dict is None: # base case if tpm_dict is None: # base case
return return
for item, item_tpm in tpm_dict.items(): all_deployments = tpm_dict
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = 0
for item, item_tpm in all_deployments.items():
## get the item from model list ## get the item from model list
_deployment = None _deployment = None
for m in self.model_list: for m in healthy_deployments:
if item == m["model_info"]["id"]: if item == m["model_info"]["id"]:
_deployment = m _deployment = m
if _deployment is None: if _deployment is None:
break continue # skip to next one
_deployment_tpm = ( _deployment_tpm = (
_deployment.get("tpm", None) _deployment.get("tpm", None)
or _deployment.get("litellm_params", {}).get("tpm", None) or _deployment.get("litellm_params", {}).get("tpm", None)
@@ -163,7 +170,8 @@ class LowestTPMLoggingHandler(CustomLogger):
) )
if item_tpm == 0: if item_tpm == 0:
return item deployment = _deployment
break
elif ( elif (
item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm
): # if user passed in tpm / rpm in the model_list ): # if user passed in tpm / rpm in the model_list
@@ -171,4 +179,6 @@ class LowestTPMLoggingHandler(CustomLogger):
elif item_tpm < lowest_tpm: elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm lowest_tpm = item_tpm
deployment = _deployment deployment = _deployment
if deployment is None:
deployment = random.choice(healthy_deployments)
return deployment return deployment

View file

@@ -1,7 +1,7 @@
#### What this tests #### #### What this tests ####
# This tests the router's ability to pick deployment with lowest tpm # This tests the router's ability to pick deployment with lowest tpm
import sys, os, asyncio, time import sys, os, asyncio, time, random
from datetime import datetime from datetime import datetime
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -120,11 +120,15 @@ def test_get_available_deployments():
) )
## CHECK WHAT'S SELECTED ## ## CHECK WHAT'S SELECTED ##
print(lowest_tpm_logger.get_available_deployments(model_group=model_group)) print(
lowest_tpm_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
)
assert ( assert (
lowest_tpm_logger.get_available_deployments(model_group=model_group)[ lowest_tpm_logger.get_available_deployments(
"model_info" model_group=model_group, healthy_deployments=model_list
]["id"] )["model_info"]["id"]
== "5678" == "5678"
) )
@@ -157,16 +161,6 @@ def test_router_get_available_deployments():
}, },
"model_info": {"id": 2}, "model_info": {"id": 2},
}, },
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
},
] ]
router = Router( router = Router(
model_list=model_list, model_list=model_list,
@@ -224,3 +218,67 @@ def test_router_get_available_deployments():
# test_router_get_available_deployments() # test_router_get_available_deployments()
@pytest.mark.asyncio
async def test_router_completion_streaming():
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
num_retries=3,
) # type: ignore
### Make 3 calls, test if 3rd call goes to lowest tpm deployment
## CALL 1+2
tasks = []
response = None
final_response = None
for _ in range(2):
tasks.append(router.acompletion(model=model, messages=messages))
response = await asyncio.gather(*tasks)
if response is not None:
## CALL 3
await asyncio.sleep(1) # let the token update happen
current_minute = datetime.now().strftime("%H-%M")
picked_deployment = router.lowesttpm_logger.get_available_deployments(
model_group=model, healthy_deployments=router.healthy_deployments
)
final_response = await router.acompletion(model=model, messages=messages)
print(f"min deployment id: {picked_deployment}")
print(f"model id: {final_response._hidden_params['model_id']}")
assert (
final_response._hidden_params["model_id"]
== picked_deployment["model_info"]["id"]
)
# asyncio.run(test_router_completion_streaming())