forked from phoenix/litellm-mirror
feat(lowest_latency.py): route by time to first token, for streaming requests (if available)
Closes https://github.com/BerriAI/litellm/issues/3574
This commit is contained in:
parent 620e6db027
commit 2b3da449c8
3 changed files with 232 additions and 18 deletions
@@ -83,8 +83,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"

            response_ms: timedelta = end_time - start_time
+           time_to_first_token_response_time: Optional[timedelta] = None
+
+           if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
+               # only log ttft for streaming request
+               time_to_first_token_response_time = (
+                   kwargs.get("completion_start_time", end_time) - start_time
+               )

            final_value = response_ms
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
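In plain terms: for streaming requests the handler now records a second duration, the gap between the request's start_time and the completion_start_time stamped when the first chunk arrives (falling back to end_time if no such timestamp exists). A minimal, standalone sketch of that calculation, using a hypothetical helper rather than the handler's real code path:

from datetime import datetime, timedelta
from typing import Optional

def ttft_delta(kwargs: dict, start_time: datetime, end_time: datetime) -> Optional[timedelta]:
    # Hypothetical helper mirroring the logged value: only streaming requests
    # get a TTFT sample; if no first-chunk timestamp was recorded, the fallback
    # is end_time, which makes TTFT equal to the total response time.
    if kwargs.get("stream") is True:
        return kwargs.get("completion_start_time", end_time) - start_time
    return None

now = datetime.now()
print(ttft_delta({"stream": True, "completion_start_time": now + timedelta(seconds=1)}, now, now + timedelta(seconds=4)))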
@@ -92,6 +99,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
                total_tokens = response_obj.usage.total_tokens
                final_value = float(response_ms.total_seconds() / completion_tokens)

+               if time_to_first_token_response_time is not None:
+                   time_to_first_token = float(
+                       time_to_first_token_response_time.total_seconds()
+                       / completion_tokens
+                   )
+
            # ------------
            # Update usage
            # ------------
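Both the overall latency and the TTFT sample are divided by the number of completion tokens, so a deployment that streams a long answer is not penalized for the answer's length. A quick worked example with assumed numbers:

# Assumed numbers, only to illustrate the per-token normalization above.
response_seconds = 4.0        # end_time - start_time
ttft_seconds = 1.0            # completion_start_time - start_time
completion_tokens = 50

latency_per_token = response_seconds / completion_tokens   # 0.08 s/token
ttft_per_token = ttft_seconds / completion_tokens          # 0.02 s/token
print(latency_per_token, ttft_per_token)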
@@ -112,6 +125,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                    "latency"
                ][: self.routing_args.max_latency_list_size - 1] + [final_value]

+           ## Time to first token
+           if time_to_first_token is not None:
+               if (
+                   len(request_count_dict[id].get("time_to_first_token", []))
+                   < self.routing_args.max_latency_list_size
+               ):
+                   request_count_dict[id].setdefault(
+                       "time_to_first_token", []
+                   ).append(time_to_first_token)
+               else:
+                   request_count_dict[id][
+                       "time_to_first_token"
+                   ] = request_count_dict[id]["time_to_first_token"][
+                       : self.routing_args.max_latency_list_size - 1
+                   ] + [
+                       time_to_first_token
+                   ]
+
            if precise_minute not in request_count_dict[id]:
                request_count_dict[id][precise_minute] = {}
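The new "time_to_first_token" list is maintained exactly like the existing "latency" list: values are appended until the list reaches max_latency_list_size, after which the list is truncated to its first max_latency_list_size - 1 entries before the new sample is appended. A simplified, standalone version of that policy (function name is illustrative):

def push_sample(samples: list, value: float, max_size: int) -> list:
    # Bounded sample list, matching the handler's latency / time_to_first_token
    # bookkeeping: grow until max_size, then truncate before appending.
    if len(samples) < max_size:
        return samples + [value]
    return samples[:max_size - 1] + [value]

window = []
for v in [0.02, 0.03, 0.01, 0.05]:
    window = push_sample(window, v, max_size=3)
print(window)   # [0.02, 0.03, 0.05]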
@@ -226,6 +257,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
        {model_group}_map: {
            id: {
                "latency": [..]
+               "time_to_first_token": [..]
                f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
            }
        }
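Concretely, after a couple of streaming calls the cached map for one model group could look like the following (values are made up; the date key follows the date-hour-minute format used above):

azure_model_map = {
    "1": {
        "latency": [0.08, 0.07],                      # assumed: seconds per output token
        "time_to_first_token": [0.02, 0.03],          # assumed: seconds per output token
        "2024-05-10-14-05": {"tpm": 34, "rpm": 3},    # usage recorded for that minute
    },
}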
@@ -239,15 +271,27 @@ class LowestLatencyLoggingHandler(CustomLogger):
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"

            response_ms: timedelta = end_time - start_time
+           time_to_first_token_response_time: Optional[timedelta] = None
+           if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
+               # only log ttft for streaming request
+               time_to_first_token_response_time = (
+                   kwargs.get("completion_start_time", end_time) - start_time
+               )

            final_value = response_ms
            total_tokens = 0
+           time_to_first_token: Optional[float] = None

            if isinstance(response_obj, ModelResponse):
                completion_tokens = response_obj.usage.completion_tokens
                total_tokens = response_obj.usage.total_tokens
                final_value = float(response_ms.total_seconds() / completion_tokens)

+               if time_to_first_token_response_time is not None:
+                   time_to_first_token = float(
+                       time_to_first_token_response_time.total_seconds()
+                       / completion_tokens
+                   )
+
            # ------------
            # Update usage
            # ------------
@@ -268,6 +312,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                    "latency"
                ][: self.routing_args.max_latency_list_size - 1] + [final_value]

+           ## Time to first token
+           if time_to_first_token is not None:
+               if (
+                   len(request_count_dict[id].get("time_to_first_token", []))
+                   < self.routing_args.max_latency_list_size
+               ):
+                   request_count_dict[id].setdefault(
+                       "time_to_first_token", []
+                   ).append(time_to_first_token)
+               else:
+                   request_count_dict[id][
+                       "time_to_first_token"
+                   ] = request_count_dict[id]["time_to_first_token"][
+                       : self.routing_args.max_latency_list_size - 1
+                   ] + [
+                       time_to_first_token
+                   ]
+
            if precise_minute not in request_count_dict[id]:
                request_count_dict[id][precise_minute] = {}
@@ -370,11 +432,22 @@ class LowestLatencyLoggingHandler(CustomLogger):
                    or float("inf")
                )
                item_latency = item_map.get("latency", [])
+               item_ttft_latency = item_map.get("time_to_first_token", [])
                item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
                item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

-               # get average latency
+               # get average latency or average ttft (depending on streaming/non-streaming)
                total: float = 0.0
+               if (
+                   request_kwargs is not None
+                   and request_kwargs.get("stream", None) is not None
+                   and request_kwargs["stream"] == True
+                   and len(item_ttft_latency) > 0
+               ):
+                   for _call_latency in item_ttft_latency:
+                       if isinstance(_call_latency, float):
+                           total += _call_latency
+               else:
                    for _call_latency in item_latency:
                        if isinstance(_call_latency, float):
                            total += _call_latency
|
||||||
|
|
||||||
# Find deployments within buffer of lowest latency
|
# Find deployments within buffer of lowest latency
|
||||||
buffer = self.routing_args.lowest_latency_buffer * lowest_latency
|
buffer = self.routing_args.lowest_latency_buffer * lowest_latency
|
||||||
|
|
||||||
valid_deployments = [
|
valid_deployments = [
|
||||||
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
|
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
|
||||||
]
|
]
|
||||||
|
|
|
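For context, the buffer filter keeps any deployment whose score is within lowest_latency_buffer × lowest_latency of the best one; with this change that score can be a TTFT average for streaming traffic. A toy example with assumed numbers:

sorted_deployments = [("deployment-a", 0.020), ("deployment-b", 0.021), ("deployment-c", 0.090)]
lowest_latency = sorted_deployments[0][1]
lowest_latency_buffer = 0.5     # assumed value of routing_args.lowest_latency_buffer

buffer = lowest_latency_buffer * lowest_latency          # 0.010
valid_deployments = [x for x in sorted_deployments if x[1] <= lowest_latency + buffer]
print(valid_deployments)        # deployment-a and deployment-b survive; deployment-c is filtered out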
@@ -38,8 +38,7 @@ class CompletionCustomHandler(
    # Class variables or attributes
    def __init__(self):
        self.errors = []
-       self.states: Optional[
-           List[
+       self.states: List[
            Literal[
                "sync_pre_api_call",
                "async_pre_api_call",
@@ -51,7 +50,6 @@ class CompletionCustomHandler(
                "sync_failure",
                "async_failure",
            ]
-           ]
        ] = []

    def log_pre_api_call(self, model, messages, kwargs):
@@ -269,6 +267,7 @@ class CompletionCustomHandler(
        assert isinstance(kwargs["litellm_params"]["api_base"], str)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
+       assert isinstance(kwargs["completion_start_time"], datetime)
        assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        assert isinstance(kwargs["input"], (list, dict, str))
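Because the success callback now asserts that kwargs["completion_start_time"] is a datetime, user-side callbacks can compute TTFT the same way the router does. A hedged sketch of such a callback (the class below is illustrative and assumes litellm's CustomLogger base class and callback signature):

from litellm.integrations.custom_logger import CustomLogger

class TTFTLogger(CustomLogger):
    # Illustrative custom callback: report time-to-first-token for streaming calls.
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        if kwargs.get("stream") is True:
            ttft = (kwargs.get("completion_start_time", end_time) - start_time).total_seconds()
            print(f"time to first token: {ttft:.3f}s")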
@@ -2,7 +2,7 @@
# This tests the router's ability to pick deployment with lowest latency

import sys, os, asyncio, time, random
-from datetime import datetime
+from datetime import datetime, timedelta
import traceback
from dotenv import load_dotenv

@@ -16,6 +16,7 @@ import pytest
from litellm import Router
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.caching import DualCache
+import litellm

### UNIT TESTS FOR LATENCY ROUTING ###
@@ -813,3 +814,143 @@ async def test_lowest_latency_routing_buffer(buffer):
        assert len(selected_deployments.keys()) == 1
    else:
        assert len(selected_deployments.keys()) == 2


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_lowest_latency_routing_time_to_first_token(sync_mode):
+    """
+    If a deployment has
+    - a fast time to first token
+    - slow latency/output token
+
+    test if:
+    - for streaming, the deployment with fastest time to first token is picked
+    - for non-streaming, fastest overall deployment is picked
+    """
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="latency-based-routing",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+    ## DEPLOYMENT 1 ##
+    deployment_id = 1
+    start_time = datetime.now()
+    one_second_later = start_time + timedelta(seconds=1)
+
+    # Compute 3 seconds after the current time
+    three_seconds_later = start_time + timedelta(seconds=3)
+    four_seconds_later = start_time + timedelta(seconds=4)
+
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "azure-model",
+            },
+            "model_info": {"id": 1},
+        },
+        "stream": True,
+        "completion_start_time": one_second_later,
+    }
+
+    response_obj = litellm.ModelResponse(
+        usage=litellm.Usage(completion_tokens=50, total_tokens=50)
+    )
+    end_time = four_seconds_later
+
+    if sync_mode:
+        router.lowestlatency_logger.log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+    else:
+        await router.lowestlatency_logger.async_log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+    ## DEPLOYMENT 2 ##
+    deployment_id = 2
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "azure-model",
+            },
+            "model_info": {"id": 2},
+        },
+        "stream": True,
+        "completion_start_time": three_seconds_later,
+    }
+    response_obj = litellm.ModelResponse(
+        usage=litellm.Usage(completion_tokens=50, total_tokens=50)
+    )
+    end_time = three_seconds_later
+    if sync_mode:
+        router.lowestlatency_logger.log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+    else:
+        await router.lowestlatency_logger.async_log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+    """
+    TESTING
+
+    - expect deployment 1 to be picked for streaming
+    - expect deployment 2 to be picked for non-streaming
+    """
+    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    selected_deployments = {}
+    for _ in range(3):
+        print(router.get_available_deployment(model="azure-model"))
+        ## for non-streaming
+        selected_deployments[
+            router.get_available_deployment(model="azure-model")["model_info"]["id"]
+        ] = 1
+
+    assert len(selected_deployments.keys()) == 1
+    assert "2" in list(selected_deployments.keys())
+
+    selected_deployments = {}
+    for _ in range(50):
+        print(router.get_available_deployment(model="azure-model"))
+        ## for non-streaming
+        selected_deployments[
+            router.get_available_deployment(
+                model="azure-model", request_kwargs={"stream": True}
+            )["model_info"]["id"]
+        ] = 1
+
+    assert len(selected_deployments.keys()) == 1
+    assert "1" in list(selected_deployments.keys())
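As the test shows, the streaming hint is passed through request_kwargs when asking the router for a deployment, so the same model group can resolve differently for streaming and non-streaming traffic. Usage mirrors the calls above:

# Assuming `router` is configured with routing_strategy="latency-based-routing"
# and has already logged some traffic, as in the test above:

# non-streaming: scored by overall latency per output token
print(router.get_available_deployment(model="azure-model"))

# streaming: scored by the recorded time-to-first-token samples
print(router.get_available_deployment(model="azure-model", request_kwargs={"stream": True}))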