feat(lowest_latency.py): route by time to first token, for streaming requests (if available)
Closes https://github.com/BerriAI/litellm/issues/3574
This commit is contained in:
parent 620e6db027
commit 2b3da449c8

3 changed files with 232 additions and 18 deletions
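In outline: when the incoming request is a streaming one and time-to-first-token (TTFT) samples exist for a deployment, the strategy ranks deployments by average TTFT instead of end-to-end latency. A minimal standalone sketch of that selection rule (function and variable names here are illustrative, not litellm's API):

from typing import Dict, List, Optional

def pick_deployment(
    stats: Dict[str, Dict[str, List[float]]],
    is_streaming: bool,
) -> Optional[str]:
    """Pick the deployment with the lowest average latency.

    For streaming requests, prefer time-to-first-token samples when
    present; otherwise fall back to end-to-end latency samples.
    """
    best_id, best_latency = None, float("inf")
    for deployment_id, samples in stats.items():
        ttft = samples.get("time_to_first_token", [])
        latency = samples.get("latency", [])
        chosen = ttft if (is_streaming and len(ttft) > 0) else latency
        if not chosen:
            continue
        avg = sum(chosen) / len(chosen)
        if avg < best_latency:
            best_id, best_latency = deployment_id, avg
    return best_id

# Deployment "b" has worse total latency but better TTFT, so a streaming
# request routes to "b" while a non-streaming one picks "a".
stats = {
    "a": {"latency": [0.020], "time_to_first_token": [0.012]},
    "b": {"latency": [0.025], "time_to_first_token": [0.008]},
}
assert pick_deployment(stats, is_streaming=True) == "b"
assert pick_deployment(stats, is_streaming=False) == "a"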
@@ -83,8 +83,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
             precise_minute = f"{current_date}-{current_hour}-{current_minute}"
 
             response_ms: timedelta = end_time - start_time
+            time_to_first_token_response_time: Optional[timedelta] = None
+            if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
+                # only log ttft for streaming request
+                time_to_first_token_response_time = (
+                    kwargs.get("completion_start_time", end_time) - start_time
+                )
+
             final_value = response_ms
 
             total_tokens = 0
 
             if isinstance(response_obj, ModelResponse):
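The hunk above derives TTFT from a completion_start_time stamped into kwargs; the kwargs.get(..., end_time) default degrades gracefully to total latency when no first-chunk time was recorded. A hedged sketch of how such a timestamp is typically captured at the first streamed chunk (generic wrapper, not litellm's internal callback wiring):

from datetime import datetime
from typing import Iterable, Iterator, Optional

def stream_with_ttft(chunks: Iterable[str], kwargs: dict) -> Iterator[str]:
    """Yield chunks, recording when the first one arrives."""
    first_seen: Optional[datetime] = None
    for chunk in chunks:
        if first_seen is None:
            first_seen = datetime.now()
            # consumed later as completion_start_time - start_time
            kwargs["completion_start_time"] = first_seen
        yield chunk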
@@ -92,6 +99,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 total_tokens = response_obj.usage.total_tokens
                 final_value = float(response_ms.total_seconds() / completion_tokens)
 
+                if time_to_first_token_response_time is not None:
+                    time_to_first_token = float(
+                        time_to_first_token_response_time.total_seconds()
+                        / completion_tokens
+                    )
+
             # ------------
             # Update usage
             # ------------
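Both metrics are normalized by completion_tokens, so a deployment that happens to serve longer responses is not penalized for it. A worked example (numbers illustrative):

from datetime import timedelta

response_ms = timedelta(seconds=2.0)  # end-to-end latency
ttft = timedelta(seconds=0.5)         # time to first token
completion_tokens = 100

final_value = response_ms.total_seconds() / completion_tokens    # 0.02 s/token
time_to_first_token = ttft.total_seconds() / completion_tokens   # 0.005 s/token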
@@ -112,6 +125,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                     "latency"
                 ][: self.routing_args.max_latency_list_size - 1] + [final_value]
 
+            ## Time to first token
+            if time_to_first_token is not None:
+                if (
+                    len(request_count_dict[id].get("time_to_first_token", []))
+                    < self.routing_args.max_latency_list_size
+                ):
+                    request_count_dict[id].setdefault(
+                        "time_to_first_token", []
+                    ).append(time_to_first_token)
+                else:
+                    request_count_dict[id][
+                        "time_to_first_token"
+                    ] = request_count_dict[id]["time_to_first_token"][
+                        : self.routing_args.max_latency_list_size - 1
+                    ] + [
+                        time_to_first_token
+                    ]
+
             if precise_minute not in request_count_dict[id]:
                 request_count_dict[id][precise_minute] = {}
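This bookkeeping keeps at most self.routing_args.max_latency_list_size samples per deployment: append while under the cap; once full, keep the first max_latency_list_size - 1 samples and append the newest value at the tail. A standalone sketch of the same update rule (function name illustrative):

from typing import Dict, List

def record_sample(
    samples: Dict[str, List[float]], key: str, value: float, max_size: int
) -> None:
    """Append value, capping the list at max_size entries.

    When full, truncate to the first max_size - 1 samples, then append
    the newest -- mirroring the update in the hunk above.
    """
    bucket = samples.setdefault(key, [])
    if len(bucket) < max_size:
        bucket.append(value)
    else:
        samples[key] = bucket[: max_size - 1] + [value]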
@@ -226,6 +257,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
         {model_group}_map: {
             id: {
                 "latency": [..]
+                "time_to_first_token": [..]
                 f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
             }
         }
@@ -239,15 +271,27 @@ class LowestLatencyLoggingHandler(CustomLogger):
             precise_minute = f"{current_date}-{current_hour}-{current_minute}"
 
             response_ms: timedelta = end_time - start_time
+            time_to_first_token_response_time: Optional[timedelta] = None
+            if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
+                # only log ttft for streaming request
+                time_to_first_token_response_time = (
+                    kwargs.get("completion_start_time", end_time) - start_time
+                )
 
             final_value = response_ms
             total_tokens = 0
+            time_to_first_token: Optional[float] = None
 
             if isinstance(response_obj, ModelResponse):
                 completion_tokens = response_obj.usage.completion_tokens
                 total_tokens = response_obj.usage.total_tokens
                 final_value = float(response_ms.total_seconds() / completion_tokens)
 
+                if time_to_first_token_response_time is not None:
+                    time_to_first_token = float(
+                        time_to_first_token_response_time.total_seconds()
+                        / completion_tokens
+                    )
             # ------------
             # Update usage
             # ------------
@@ -268,6 +312,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                     "latency"
                 ][: self.routing_args.max_latency_list_size - 1] + [final_value]
 
+            ## Time to first token
+            if time_to_first_token is not None:
+                if (
+                    len(request_count_dict[id].get("time_to_first_token", []))
+                    < self.routing_args.max_latency_list_size
+                ):
+                    request_count_dict[id].setdefault(
+                        "time_to_first_token", []
+                    ).append(time_to_first_token)
+                else:
+                    request_count_dict[id][
+                        "time_to_first_token"
+                    ] = request_count_dict[id]["time_to_first_token"][
+                        : self.routing_args.max_latency_list_size - 1
+                    ] + [
+                        time_to_first_token
+                    ]
+
             if precise_minute not in request_count_dict[id]:
                 request_count_dict[id][precise_minute] = {}
@@ -370,14 +432,25 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 or float("inf")
             )
             item_latency = item_map.get("latency", [])
+            item_ttft_latency = item_map.get("time_to_first_token", [])
             item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
             item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
 
-            # get average latency
+            # get average latency or average ttft (depending on streaming/non-streaming)
             total: float = 0.0
-            for _call_latency in item_latency:
-                if isinstance(_call_latency, float):
-                    total += _call_latency
+            if (
+                request_kwargs is not None
+                and request_kwargs.get("stream", None) is not None
+                and request_kwargs["stream"] == True
+                and len(item_ttft_latency) > 0
+            ):
+                for _call_latency in item_ttft_latency:
+                    if isinstance(_call_latency, float):
+                        total += _call_latency
+            else:
+                for _call_latency in item_latency:
+                    if isinstance(_call_latency, float):
+                        total += _call_latency
             item_latency = total / len(item_latency)
 
             # -------------- #
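For a streaming request with recorded TTFT samples, the total sums the TTFT list; note that the divisor is len(item_latency) in both branches as written. A worked example of the streaming branch (numbers illustrative):

item_latency = [0.02, 0.03, 0.04]    # seconds/token, end-to-end samples
item_ttft_latency = [0.004, 0.006]   # seconds/token, first-token samples

# Streaming request: sum the TTFT samples, divide by len(item_latency)
# exactly as in the hunk above.
total = sum(x for x in item_ttft_latency if isinstance(x, float))
avg = total / len(item_latency)      # 0.010 / 3 ≈ 0.0033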
@@ -413,6 +486,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
 
         # Find deployments within buffer of lowest latency
         buffer = self.routing_args.lowest_latency_buffer * lowest_latency
+
         valid_deployments = [
             x for x in sorted_deployments if x[1] <= lowest_latency + buffer
         ]
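The buffer widens the pick from "the single fastest deployment" to every deployment within lowest_latency_buffer * lowest_latency of the best average, which spreads load across near-equivalent deployments. A worked example, assuming a buffer factor of 0.5 and a random final pick (the selection step itself is not shown in this hunk):

import random

lowest_latency_buffer = 0.5
sorted_deployments = [("a", 0.010), ("b", 0.014), ("c", 0.030)]  # (id, avg latency)

lowest_latency = sorted_deployments[0][1]
buffer = lowest_latency_buffer * lowest_latency  # 0.005
valid_deployments = [
    x for x in sorted_deployments if x[1] <= lowest_latency + buffer
]  # [("a", 0.010), ("b", 0.014)] -- "c" falls outside the buffer
print(random.choice(valid_deployments))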