Merge pull request #3510 from BerriAI/litellm_make_lowest_cost_async

[Feat] Make lowest cost routing Async
This commit is contained in:
Ishaan Jaff 2024-05-07 14:14:04 -07:00 committed by GitHub
commit 84055c0546
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 41 additions and 28 deletions

View file

@ -2958,6 +2958,7 @@ class Router:
if (
self.routing_strategy != "usage-based-routing-v2"
and self.routing_strategy != "simple-shuffle"
and self.routing_strategy != "cost-based-routing"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment(
model=model,
@ -3014,6 +3015,16 @@ class Router:
messages=messages,
input=input,
)
if (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
deployment = await self.lowestcost_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
messages=messages,
input=input,
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
@ -3184,15 +3195,6 @@ class Router:
messages=messages,
input=input,
)
elif (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
deployment = self.lowestcost_logger.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"

View file

@ -40,7 +40,7 @@ class LowestCostLoggingHandler(CustomLogger):
self.router_cache = router_cache
self.model_list = model_list
def log_success_event(self, kwargs, response_obj, start_time, end_time):
async def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update usage on success
@ -90,7 +90,11 @@ class LowestCostLoggingHandler(CustomLogger):
# Update usage
# ------------
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
request_count_dict = (
await self.router_cache.async_get_cache(key=cost_key) or {}
)
# check local result first
if id not in request_count_dict:
request_count_dict[id] = {}
@ -111,7 +115,9 @@ class LowestCostLoggingHandler(CustomLogger):
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(key=cost_key, value=request_count_dict)
await self.router_cache.async_set_cache(
key=cost_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
@ -172,7 +178,9 @@ class LowestCostLoggingHandler(CustomLogger):
# Update usage
# ------------
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
request_count_dict = (
await self.router_cache.async_get_cache(key=cost_key) or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
@ -189,7 +197,7 @@ class LowestCostLoggingHandler(CustomLogger):
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(
await self.router_cache.async_set_cache(
key=cost_key, value=request_count_dict
) # reset map within window
@ -200,7 +208,7 @@ class LowestCostLoggingHandler(CustomLogger):
traceback.print_exc()
pass
def get_available_deployments(
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
@ -213,7 +221,7 @@ class LowestCostLoggingHandler(CustomLogger):
"""
cost_key = f"{model_group}_map"
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
# -----------------------
# Find lowest used model

View file

@ -20,7 +20,8 @@ from litellm.caching import DualCache
### UNIT TESTS FOR cost ROUTING ###
def test_get_available_deployments():
@pytest.mark.asyncio
async def test_get_available_deployments():
test_cache = DualCache()
model_list = [
{
@ -40,7 +41,7 @@ def test_get_available_deployments():
model_group = "gpt-3.5-turbo"
## CHECK WHAT'S SELECTED ##
selected_model = lowest_cost_logger.get_available_deployments(
selected_model = await lowest_cost_logger.async_get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
print("selected model: ", selected_model)
@ -48,7 +49,8 @@ def test_get_available_deployments():
assert selected_model["model_info"]["id"] == "groq-llama"
def test_get_available_deployments_custom_price():
@pytest.mark.asyncio
async def test_get_available_deployments_custom_price():
from litellm._logging import verbose_router_logger
import logging
@ -89,7 +91,7 @@ def test_get_available_deployments_custom_price():
model_group = "gpt-3.5-turbo"
## CHECK WHAT'S SELECTED ##
selected_model = lowest_cost_logger.get_available_deployments(
selected_model = await lowest_cost_logger.async_get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
print("selected model: ", selected_model)
@ -142,7 +144,7 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
response_obj = {"usage": {"total_tokens": tokens_used}}
time.sleep(duration)
end_time = time.time()
lowest_cost_logger.log_success_event(
await lowest_cost_logger.async_log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
@ -150,14 +152,11 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
)
async def _gather_deploy(all_deploys):
return await asyncio.gather(*[_deploy(*t) for t in all_deploys])
@pytest.mark.parametrize(
"ans_rpm", [1, 5]
) # 1 should produce nothing, 10 should select first
def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
@pytest.mark.asyncio
async def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
"""
Pass in list of 2 valid models
@ -193,9 +192,13 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
model_group = "gpt-3.5-turbo"
d1 = [(lowest_cost_logger, "1234", 50, 0.01)] * non_ans_rpm
d2 = [(lowest_cost_logger, "5678", 50, 0.01)] * non_ans_rpm
asyncio.run(_gather_deploy([*d1, *d2]))
await asyncio.gather(*[_deploy(*t) for t in [*d1, *d2]])
asyncio.sleep(3)
## CHECK WHAT'S SELECTED ##
d_ans = lowest_cost_logger.get_available_deployments(
d_ans = await lowest_cost_logger.async_get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
assert (d_ans and d_ans["model_info"]["id"]) == ans