feat(lowest_latency.py): route by time to first token for streaming requests (if available)

Closes https://github.com/BerriAI/litellm/issues/3574
Krrish Dholakia 2024-05-21 13:08:17 -07:00
parent 620e6db027
commit 2b3da449c8
3 changed files with 232 additions and 18 deletions
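In practice the change surfaces through latency-based routing: once a deployment has streamed responses logged, streaming requests are scored by time to first token (TTFT) instead of total latency. A minimal usage sketch (model names and keys are placeholders):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
            },
        ],
        routing_strategy="latency-based-routing",
    )

    # streaming requests are now routed by recorded TTFT, when available
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )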


@@ -83,8 +83,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
            response_ms: timedelta = end_time - start_time
            time_to_first_token_response_time: Optional[timedelta] = None
            if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
                # only log ttft for streaming requests
                time_to_first_token_response_time = (
                    kwargs.get("completion_start_time", end_time) - start_time
                )

            final_value = response_ms
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
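Here completion_start_time is the timestamp litellm's logging records when the first streamed chunk arrives; if it was never recorded, the kwargs.get("completion_start_time", end_time) fallback collapses TTFT to total latency. A sketch with hypothetical timestamps:

    from datetime import datetime, timedelta

    start_time = datetime(2024, 5, 21, 13, 0, 0)
    completion_start_time = start_time + timedelta(milliseconds=350)  # first chunk
    end_time = start_time + timedelta(seconds=2, milliseconds=100)    # last chunk

    ttft = completion_start_time - start_time   # time to first token
    total = end_time - start_time               # end-to-end latency
    print(ttft.total_seconds(), total.total_seconds())  # 0.35 2.1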
@@ -92,6 +99,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
                total_tokens = response_obj.usage.total_tokens
                final_value = float(response_ms.total_seconds() / completion_tokens)

                time_to_first_token: Optional[float] = None
                if time_to_first_token_response_time is not None:
                    time_to_first_token = float(
                        time_to_first_token_response_time.total_seconds()
                        / completion_tokens
                    )

            # ------------
            # Update usage
            # ------------
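Both scores are normalized by completion_tokens, so the cached values are seconds per token for TTFT as well as for total latency, which presumably keeps the two directly comparable when ranking deployments. With hypothetical numbers:

    completion_tokens = 100
    final_value = 2.1 / completion_tokens            # 0.021  s/token, total-latency score
    time_to_first_token = 0.35 / completion_tokens   # 0.0035 s/token, TTFT score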
@@ -112,6 +125,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                    "latency"
                ][: self.routing_args.max_latency_list_size - 1] + [final_value]

            ## Time to first token
            if time_to_first_token is not None:
                if (
                    len(request_count_dict[id].get("time_to_first_token", []))
                    < self.routing_args.max_latency_list_size
                ):
                    request_count_dict[id].setdefault(
                        "time_to_first_token", []
                    ).append(time_to_first_token)
                else:
                    request_count_dict[id][
                        "time_to_first_token"
                    ] = request_count_dict[id]["time_to_first_token"][
                        : self.routing_args.max_latency_list_size - 1
                    ] + [time_to_first_token]

            if precise_minute not in request_count_dict[id]:
                request_count_dict[id][precise_minute] = {}
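The time_to_first_token list is bounded exactly like the latency list: grow until max_latency_list_size, then keep the first max_latency_list_size - 1 entries and append the new sample. A standalone sketch of that policy (the helper name and max size are illustrative only):

    def append_bounded(values: list, new_value: float, max_size: int) -> list:
        # grow until max_size, then keep the first max_size - 1 entries
        # and append the newest sample, mirroring the slice above
        if len(values) < max_size:
            return values + [new_value]
        return values[: max_size - 1] + [new_value]

    window: list = []
    for v in [0.0035, 0.0031, 0.0040, 0.0029]:
        window = append_bounded(window, v, max_size=3)
    print(window)  # [0.0035, 0.0031, 0.0029]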
@@ -226,6 +257,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
        {model_group}_map: {
            id: {
                "latency": [..]
                "time_to_first_token": [..]
                f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
            }
        }
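A hypothetical snapshot of that cached map, with the per-token scores the handlers above produce:

    {
        "gpt-3.5-turbo_map": {
            "deployment-id-1": {
                "latency": [0.021, 0.018],
                "time_to_first_token": [0.0035, 0.0031],
                "2024-05-21-13-08": {"tpm": 34, "rpm": 3},
            }
        }
    }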
@@ -239,15 +271,27 @@ class LowestLatencyLoggingHandler(CustomLogger):
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
            response_ms: timedelta = end_time - start_time
            time_to_first_token_response_time: Optional[timedelta] = None
            if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
                # only log ttft for streaming requests
                time_to_first_token_response_time = (
                    kwargs.get("completion_start_time", end_time) - start_time
                )

            final_value = response_ms
            total_tokens = 0
            time_to_first_token: Optional[float] = None

            if isinstance(response_obj, ModelResponse):
                completion_tokens = response_obj.usage.completion_tokens
                total_tokens = response_obj.usage.total_tokens
                final_value = float(response_ms.total_seconds() / completion_tokens)
                if time_to_first_token_response_time is not None:
                    time_to_first_token = float(
                        time_to_first_token_response_time.total_seconds()
                        / completion_tokens
                    )

            # ------------
            # Update usage
            # ------------
@@ -268,6 +312,24 @@ class LowestLatencyLoggingHandler(CustomLogger):
                    "latency"
                ][: self.routing_args.max_latency_list_size - 1] + [final_value]

            ## Time to first token
            if time_to_first_token is not None:
                if (
                    len(request_count_dict[id].get("time_to_first_token", []))
                    < self.routing_args.max_latency_list_size
                ):
                    request_count_dict[id].setdefault(
                        "time_to_first_token", []
                    ).append(time_to_first_token)
                else:
                    request_count_dict[id][
                        "time_to_first_token"
                    ] = request_count_dict[id]["time_to_first_token"][
                        : self.routing_args.max_latency_list_size - 1
                    ] + [time_to_first_token]

            if precise_minute not in request_count_dict[id]:
                request_count_dict[id][precise_minute] = {}
@@ -370,14 +432,25 @@ class LowestLatencyLoggingHandler(CustomLogger):
                or float("inf")
            )
            item_latency = item_map.get("latency", [])
            item_ttft_latency = item_map.get("time_to_first_token", [])
            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

            # get average latency or average ttft (depending on streaming/non-streaming)
            total: float = 0.0
            if (
                request_kwargs is not None
                and request_kwargs.get("stream", None) is not None
                and request_kwargs["stream"] == True
                and len(item_ttft_latency) > 0
            ):
                # streaming request with recorded ttft samples: average those
                for _call_latency in item_ttft_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                item_latency = total / len(item_ttft_latency)
            else:
                for _call_latency in item_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                item_latency = total / len(item_latency)

            # -------------- #
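Net effect: for a streaming request a deployment is scored by its mean TTFT when samples exist, otherwise (and for non-streaming requests) by its mean total latency. With the hypothetical samples from the snapshot above:

    item_ttft_latency = [0.0035, 0.0031]
    item_latency = [0.021, 0.018]

    streaming_score = sum(item_ttft_latency) / len(item_ttft_latency)  # 0.0033
    fallback_score = sum(item_latency) / len(item_latency)             # 0.0195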
@@ -413,6 +486,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
            # Find deployments within buffer of lowest latency
            buffer = self.routing_args.lowest_latency_buffer * lowest_latency

            valid_deployments = [
                x for x in sorted_deployments if x[1] <= lowest_latency + buffer
            ]
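For intuition on the buffer (numbers hypothetical): with lowest_latency_buffer = 0.5 and a best score of 0.020 s/token, every deployment scoring at most 0.030 s/token remains a candidate, so traffic spreads across near-ties instead of always hitting the single fastest deployment:

    lowest_latency = 0.020                            # best score among deployments
    lowest_latency_buffer = 0.5                       # hypothetical routing arg
    buffer = lowest_latency_buffer * lowest_latency   # 0.010
    threshold = lowest_latency + buffer               # 0.030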