forked from phoenix/litellm-mirror
fix(lowest_tpm_rpm_routing.py): broaden scope of get deployment logic
parent a6719caebd
commit b66cf0aa43
3 changed files with 90 additions and 22 deletions
@@ -1622,7 +1622,7 @@ class Router:
             and self.lowesttpm_logger is not None
         ):
             min_deployment = self.lowesttpm_logger.get_available_deployments(
-                model_group=model
+                model_group=model, healthy_deployments=healthy_deployments
             )
             if min_deployment is None:
                 min_deployment = random.choice(healthy_deployments)
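For orientation before the strategy-side hunks: the Router change above threads the already-filtered healthy_deployments list through to the TPM logger and keeps a random fallback for when the logger has nothing to report. A minimal sketch of the resulting call pattern (pick_min_tpm_deployment is a hypothetical wrapper; the surrounding Router code is assumed, not reproduced from the repo):

    import random

    def pick_min_tpm_deployment(router, model, healthy_deployments):
        # Ask the usage-based strategy for the least-loaded healthy deployment.
        min_deployment = router.lowesttpm_logger.get_available_deployments(
            model_group=model, healthy_deployments=healthy_deployments
        )
        # If the strategy has no usage data yet, fall back to a random healthy one.
        if min_deployment is None:
            min_deployment = random.choice(healthy_deployments)
        return min_deployment
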
@@ -1,7 +1,7 @@
 #### What this does ####
 # identifies lowest tpm deployment
 
-import dotenv, os, requests
+import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime
 
@@ -118,7 +118,7 @@ class LowestTPMLoggingHandler(CustomLogger):
             traceback.print_exc()
             pass
 
-    def get_available_deployments(self, model_group: str):
+    def get_available_deployments(self, model_group: str, healthy_deployments: list):
         """
         Returns a deployment with the lowest TPM/RPM usage.
         """
@@ -139,15 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
         if tpm_dict is None:  # base case
             return
 
-        for item, item_tpm in tpm_dict.items():
+        all_deployments = tpm_dict
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
+
+        for item, item_tpm in all_deployments.items():
             ## get the item from model list
             _deployment = None
-            for m in self.model_list:
+            for m in healthy_deployments:
                 if item == m["model_info"]["id"]:
                     _deployment = m
 
             if _deployment is None:
-                break
+                continue  # skip to next one
 
             _deployment_tpm = (
                 _deployment.get("tpm", None)
                 or _deployment.get("litellm_params", {}).get("tpm", None)
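The heart of the fix is the seeding step above: a deployment that has served no traffic this minute has no entry in tpm_dict, so the old loop could never pick it. A standalone sketch of that step with toy data (dict shapes taken from the diff; note the diff aliases tpm_dict directly, while this sketch copies it so the cached usage dict is not mutated):

    tpm_dict = {"1234": 50, "5678": 20}  # deployment id -> tokens used this minute
    healthy_deployments = [
        {"model_info": {"id": "1234"}},
        {"model_info": {"id": "5678"}},
        {"model_info": {"id": "9012"}},  # cold deployment, absent from tpm_dict
    ]

    all_deployments = dict(tpm_dict)  # copy; the diff assigns the dict directly
    for d in healthy_deployments:
        # Seed unused healthy deployments with zero usage so they are eligible.
        if d["model_info"]["id"] not in all_deployments:
            all_deployments[d["model_info"]["id"]] = 0

    print(all_deployments)  # {'1234': 50, '5678': 20, '9012': 0}
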
@@ -163,7 +170,8 @@ class LowestTPMLoggingHandler(CustomLogger):
             )
 
             if item_tpm == 0:
-                return item
+                deployment = _deployment
+                break
             elif (
                 item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm
             ):  # if user passed in tpm / rpm in the model_list
@@ -171,4 +179,6 @@ class LowestTPMLoggingHandler(CustomLogger):
             elif item_tpm < lowest_tpm:
                 lowest_tpm = item_tpm
                 deployment = _deployment
+        if deployment is None:
+            deployment = random.choice(healthy_deployments)
         return deployment
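Taken together, the hunks above give the selector three behaviors: a zero-usage deployment wins immediately, deployments past their user-configured tpm/rpm limits are skipped, and if nothing qualifies the router still gets a usable answer via random.choice. A simplified, self-contained sketch of the patched logic (cache lookups and error handling are elided; the "no limit" defaults and rpm_dict.get are assumptions, the real code indexes rpm_dict directly):

    import random

    def get_available_deployments_sketch(tpm_dict, rpm_dict, healthy_deployments):
        # Seed unused healthy deployments with zero usage.
        all_deployments = dict(tpm_dict or {})
        for d in healthy_deployments:
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = 0

        deployment, lowest_tpm = None, float("inf")
        for item, item_tpm in all_deployments.items():
            # Find the healthy deployment matching this id.
            _deployment = next(
                (m for m in healthy_deployments if m["model_info"]["id"] == item), None
            )
            if _deployment is None:
                continue  # no matching healthy deployment, skip to next one

            # User-configured limits; default to "no limit" when unset (assumption).
            _tpm_limit = _deployment.get("litellm_params", {}).get("tpm") or float("inf")
            _rpm_limit = _deployment.get("litellm_params", {}).get("rpm") or float("inf")

            if item_tpm == 0:
                deployment = _deployment  # unused deployment: pick it immediately
                break
            elif item_tpm > _tpm_limit or rpm_dict.get(item, 0) + 1 >= _rpm_limit:
                continue  # over the configured tpm/rpm limits
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                deployment = _deployment

        if deployment is None:
            deployment = random.choice(healthy_deployments)  # never return None
        return deployment

For example, with one warm and one cold deployment the cold one wins via the zero-usage branch:

    print(get_available_deployments_sketch(
        {"1234": 50},
        {},
        [
            {"model_info": {"id": "1234"}, "litellm_params": {}},
            {"model_info": {"id": "5678"}, "litellm_params": {}},
        ],
    )["model_info"]["id"])  # -> 5678
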
@@ -1,7 +1,7 @@
 #### What this tests ####
 # This tests the router's ability to pick deployment with lowest tpm
 
-import sys, os, asyncio, time
+import sys, os, asyncio, time, random
 from datetime import datetime
 import traceback
 from dotenv import load_dotenv
@@ -120,11 +120,15 @@ def test_get_available_deployments():
     )
 
     ## CHECK WHAT'S SELECTED ##
-    print(lowest_tpm_logger.get_available_deployments(model_group=model_group))
+    print(
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )
+    )
     assert (
-        lowest_tpm_logger.get_available_deployments(model_group=model_group)[
-            "model_info"
-        ]["id"]
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )["model_info"]["id"]
         == "5678"
     )
 
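A usage note on the updated test: the logger no longer reads self.model_list, so the candidate list must be passed explicitly. In short (fixture names as in the test above):

    selected = lowest_tpm_logger.get_available_deployments(
        model_group=model_group, healthy_deployments=model_list
    )
    assert selected["model_info"]["id"] == "5678"  # the lower-TPM deployment wins
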
@@ -157,16 +161,6 @@ def test_router_get_available_deployments():
             },
             "model_info": {"id": 2},
         },
-        {
-            "model_name": "azure-model",
-            "litellm_params": {
-                "model": "azure/gpt-35-turbo",
-                "api_key": "os.environ/AZURE_CANADA_API_KEY",
-                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-                "rpm": 6,
-            },
-            "model_info": {"id": 3},
-        },
     ]
     router = Router(
         model_list=model_list,
@@ -224,3 +218,67 @@ def test_router_get_available_deployments():
 
 
 # test_router_get_available_deployments()
+
+
+@pytest.mark.asyncio
+async def test_router_completion_streaming():
+    messages = [
+        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+    ]
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="usage-based-routing",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    ### Make 3 calls, test if 3rd call goes to lowest tpm deployment
+
+    ## CALL 1+2
+    tasks = []
+    response = None
+    final_response = None
+    for _ in range(2):
+        tasks.append(router.acompletion(model=model, messages=messages))
+    response = await asyncio.gather(*tasks)
+
+    if response is not None:
+        ## CALL 3
+        await asyncio.sleep(1)  # let the token update happen
+        current_minute = datetime.now().strftime("%H-%M")
+        picked_deployment = router.lowesttpm_logger.get_available_deployments(
+            model_group=model, healthy_deployments=router.healthy_deployments
+        )
+        final_response = await router.acompletion(model=model, messages=messages)
+        print(f"min deployment id: {picked_deployment}")
+        print(f"model id: {final_response._hidden_params['model_id']}")
+        assert (
+            final_response._hidden_params["model_id"]
+            == picked_deployment["model_info"]["id"]
+        )
+
+
+# asyncio.run(test_router_completion_streaming())
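The new async test exercises the full loop: two concurrent completions record usage, a short sleep lets the per-minute cache settle, and the third call is expected to land on whatever the strategy picks. A minimal sketch of that pattern (warm_then_pick is a hypothetical helper; the router fixture and names come from the diff):

    import asyncio

    async def warm_then_pick(router, model, messages):
        # Fire two calls concurrently so the deployments record TPM/RPM usage.
        await asyncio.gather(
            *[router.acompletion(model=model, messages=messages) for _ in range(2)]
        )
        await asyncio.sleep(1)  # let the per-minute usage cache update
        # The strategy should now report the deployment with the lowest recorded TPM.
        return router.lowesttpm_logger.get_available_deployments(
            model_group=model, healthy_deployments=router.healthy_deployments
        )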