fix(lowest_tpm_rpm_routing.py): broaden scope of get deployment logic
commit b66cf0aa43 (parent a6719caebd)
3 changed files with 90 additions and 22 deletions
@@ -1622,7 +1622,7 @@ class Router:
             and self.lowesttpm_logger is not None
         ):
             min_deployment = self.lowesttpm_logger.get_available_deployments(
-                model_group=model
+                model_group=model, healthy_deployments=healthy_deployments
             )
             if min_deployment is None:
                 min_deployment = random.choice(healthy_deployments)
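Note: the point of threading healthy_deployments through here is that a deployment with no recorded traffic has no entry in the TPM usage dict, so the old code could never select it. A minimal, self-contained sketch of the failure and the fix (names and values are illustrative, not litellm's actual code):

    tpm_dict = {"1234": 50}  # usage recorded for one deployment only
    healthy_deployments = [
        {"model_info": {"id": "1234"}},
        {"model_info": {"id": "5678"}},  # never used, so absent from tpm_dict
    ]
    # Old behavior: iterating tpm_dict alone only ever considers "1234".
    # New behavior: seed unused deployments with 0 TPM so they can win.
    for d in healthy_deployments:
        tpm_dict.setdefault(d["model_info"]["id"], 0)
    print(min(tpm_dict, key=tpm_dict.get))  # -> "5678"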
@@ -1,7 +1,7 @@
 #### What this does ####
 # identifies lowest tpm deployment

-import dotenv, os, requests
+import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime

@@ -118,7 +118,7 @@ class LowestTPMLoggingHandler(CustomLogger):
             traceback.print_exc()
             pass

-    def get_available_deployments(self, model_group: str):
+    def get_available_deployments(self, model_group: str, healthy_deployments: list):
         """
         Returns a deployment with the lowest TPM/RPM usage.
         """

@@ -139,15 +139,22 @@ class LowestTPMLoggingHandler(CustomLogger):
         if tpm_dict is None: # base case
             return

-        for item, item_tpm in tpm_dict.items():
+        all_deployments = tpm_dict
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
+
+        for item, item_tpm in all_deployments.items():
             ## get the item from model list
             _deployment = None
-            for m in self.model_list:
+            for m in healthy_deployments:
                 if item == m["model_info"]["id"]:
                     _deployment = m

             if _deployment is None:
-                break
+                continue # skip to next one

             _deployment_tpm = (
                 _deployment.get("tpm", None)
                 or _deployment.get("litellm_params", {}).get("tpm", None)

@@ -163,7 +170,8 @@ class LowestTPMLoggingHandler(CustomLogger):
             )

             if item_tpm == 0:
-                return item
+                deployment = _deployment
+                break
             elif (
                 item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm
             ): # if user passed in tpm / rpm in the model_list

@@ -171,4 +179,6 @@ class LowestTPMLoggingHandler(CustomLogger):
             elif item_tpm < lowest_tpm:
                 lowest_tpm = item_tpm
                 deployment = _deployment
+        if deployment is None:
+            deployment = random.choice(healthy_deployments)
         return deployment
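Condensed, the selection logic after this change behaves roughly like the sketch below (an assumed standalone helper, not the library's method; the real code also enforces the tpm/rpm limits a user sets in the model list):

    import random

    def pick_lowest_tpm(tpm_dict, healthy_deployments):
        # seed deployments with no recorded usage so they are eligible
        usage = dict(tpm_dict or {})
        for d in healthy_deployments:
            usage.setdefault(d["model_info"]["id"], 0)

        deployment, lowest_tpm = None, float("inf")
        for item, item_tpm in usage.items():
            match = next(
                (m for m in healthy_deployments if m["model_info"]["id"] == item),
                None,
            )
            if match is None:
                continue  # usage recorded for a deployment no longer healthy
            if item_tpm == 0:
                deployment = match  # an idle deployment always wins
                break
            if item_tpm < lowest_tpm:
                lowest_tpm, deployment = item_tpm, match
        # never return None: fall back to a random healthy deployment
        return deployment or random.choice(healthy_deployments)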
@@ -1,7 +1,7 @@
 #### What this tests ####
 # This tests the router's ability to pick deployment with lowest tpm

-import sys, os, asyncio, time
+import sys, os, asyncio, time, random
 from datetime import datetime
 import traceback
 from dotenv import load_dotenv

@@ -120,11 +120,15 @@ def test_get_available_deployments():
     )

     ## CHECK WHAT'S SELECTED ##
-    print(lowest_tpm_logger.get_available_deployments(model_group=model_group))
+    print(
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )
+    )
     assert (
-        lowest_tpm_logger.get_available_deployments(model_group=model_group)[
-            "model_info"
-        ]["id"]
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )["model_info"]["id"]
         == "5678"
     )

@@ -157,16 +161,6 @@ def test_router_get_available_deployments():
             },
             "model_info": {"id": 2},
         },
-        {
-            "model_name": "azure-model",
-            "litellm_params": {
-                "model": "azure/gpt-35-turbo",
-                "api_key": "os.environ/AZURE_CANADA_API_KEY",
-                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-                "rpm": 6,
-            },
-            "model_info": {"id": 3},
-        },
     ]
     router = Router(
         model_list=model_list,

@@ -224,3 +218,67 @@ def test_router_get_available_deployments():


 # test_router_get_available_deployments()
+
+
+@pytest.mark.asyncio
+async def test_router_completion_streaming():
+    messages = [
+        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+    ]
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="usage-based-routing",
+        set_verbose=False,
+        num_retries=3,
+    ) # type: ignore
+
+    ### Make 3 calls, test if 3rd call goes to lowest tpm deployment
+
+    ## CALL 1+2
+    tasks = []
+    response = None
+    final_response = None
+    for _ in range(2):
+        tasks.append(router.acompletion(model=model, messages=messages))
+    response = await asyncio.gather(*tasks)
+
+    if response is not None:
+        ## CALL 3
+        await asyncio.sleep(1) # let the token update happen
+        current_minute = datetime.now().strftime("%H-%M")
+        picked_deployment = router.lowesttpm_logger.get_available_deployments(
+            model_group=model, healthy_deployments=router.healthy_deployments
+        )
+        final_response = await router.acompletion(model=model, messages=messages)
+        print(f"min deployment id: {picked_deployment}")
+        print(f"model id: {final_response._hidden_params['model_id']}")
+        assert (
+            final_response._hidden_params["model_id"]
+            == picked_deployment["model_info"]["id"]
+        )
+
+
+# asyncio.run(test_router_completion_streaming())
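The new async test follows a gather-then-check pattern: fire two concurrent calls to build up usage, wait for the usage logger to record tokens, ask the TPM logger which deployment it will pick, then assert the third call actually lands there. Stripped of litellm specifics, the shape is roughly (illustrative names only, not part of the library):

    import asyncio

    async def fire_and_check(call, predict):
        await asyncio.gather(*[call() for _ in range(2)])  # CALL 1+2 in parallel
        await asyncio.sleep(1)  # give the usage logger time to record tokens
        expected = predict()    # which deployment the TPM logger will pick
        result = await call()   # CALL 3 should land on that deployment
        return expected, result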