(Feat) Add x-litellm-overhead-duration-ms and x-litellm-response-duration-ms headers in the response from LiteLLM (#7899)

* add track_llm_api_timing

* add track_llm_api_timing

* test_litellm_overhead

* use ResponseMetadata class for setting hidden params and response overhead

* instrument http handler

* fix track_llm_api_timing

* track_llm_api_timing

* emit response overhead on hidden params

* fix resp metadata

* fix make_sync_openai_embedding_request

* test_aaaaatext_completion_endpoint fixes

* _get_value_from_hidden_params

* set_hidden_params

* test_litellm_overhead

* test_litellm_overhead

* test_litellm_overhead

* fix import

* test_litellm_overhead_stream

* add LiteLLMLoggingObject

* use diff folder for testing

* use diff folder for overhead testing

* test litellm overhead

* use typing

* clear typing

* test_litellm_overhead

* fix async_streaming

* update_response_metadata

* move test file

* apply metadata to the response object
Ishaan Jaff 2025-01-21 20:27:55 -08:00 committed by GitHub
parent 63d7d04232
commit b6f2e659b9
17 changed files with 464 additions and 73 deletions

View file

@ -669,7 +669,7 @@ jobs:
paths:
- batches_coverage.xml
- batches_coverage
secret_manager_testing:
litellm_utils_testing:
docker:
- image: cimg/python:3.11
auth:
@ -697,13 +697,13 @@ jobs:
command: |
pwd
ls
python -m pytest -vv tests/secret_manager_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
python -m pytest -vv tests/litellm_utils_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml secret_manager_coverage.xml
mv .coverage secret_manager_coverage
mv coverage.xml litellm_utils_coverage.xml
mv .coverage litellm_utils_coverage
# Store test results
- store_test_results:
@ -711,8 +711,8 @@ jobs:
- persist_to_workspace:
root: .
paths:
- secret_manager_coverage.xml
- secret_manager_coverage
- litellm_utils_coverage.xml
- litellm_utils_coverage
pass_through_unit_testing:
docker:
@ -2029,7 +2029,7 @@ workflows:
only:
- main
- /litellm_.*/
- secret_manager_testing:
- litellm_utils_testing:
filters:
branches:
only:
@ -2057,7 +2057,7 @@ workflows:
requires:
- llm_translation_testing
- batches_testing
- secret_manager_testing
- litellm_utils_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing
@ -2113,7 +2113,7 @@ workflows:
- test_bad_database_url
- llm_translation_testing
- batches_testing
- secret_manager_testing
- litellm_utils_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing

View file

@ -0,0 +1,116 @@
import datetime
from typing import Any, Optional, Union
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.logging_utils import LiteLLMLoggingObject
from litellm.types.utils import (
EmbeddingResponse,
HiddenParams,
ModelResponse,
TranscriptionResponse,
)
class ResponseMetadata:
"""
Handles setting and managing `_hidden_params`, `_response_ms`, and `litellm_overhead_time_ms` for LiteLLM responses
"""
def __init__(self, result: Any):
self.result = result
self._hidden_params: Union[HiddenParams, dict] = (
getattr(result, "_hidden_params", {}) or {}
)
@property
def supports_response_time(self) -> bool:
"""Check if response type supports timing metrics"""
return (
isinstance(self.result, ModelResponse)
or isinstance(self.result, EmbeddingResponse)
or isinstance(self.result, TranscriptionResponse)
)
def set_hidden_params(
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
) -> None:
"""Set hidden parameters on the response"""
new_params = {
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
"model_id": kwargs.get("model_info", {}).get("id", None),
"api_base": get_api_base(model=model or "", optional_params=kwargs),
"response_cost": logging_obj._response_cost_calculator(result=self.result),
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
),
}
self._update_hidden_params(new_params)
def _update_hidden_params(self, new_params: dict) -> None:
"""
Update hidden params - handles when self._hidden_params is a dict or HiddenParams object
"""
# Handle both dict and HiddenParams cases
if isinstance(self._hidden_params, dict):
self._hidden_params.update(new_params)
elif isinstance(self._hidden_params, HiddenParams):
# For HiddenParams object, set attributes individually
for key, value in new_params.items():
setattr(self._hidden_params, key, value)
def _get_value_from_hidden_params(self, key: str) -> Optional[Any]:
"""Get value from hidden params - handles when self._hidden_params is a dict or HiddenParams object"""
if isinstance(self._hidden_params, dict):
return self._hidden_params.get(key, None)
elif isinstance(self._hidden_params, HiddenParams):
return getattr(self._hidden_params, key, None)
def set_timing_metrics(
self,
start_time: datetime.datetime,
end_time: datetime.datetime,
logging_obj: LiteLLMLoggingObject,
) -> None:
"""Set response timing metrics"""
total_response_time_ms = (end_time - start_time).total_seconds() * 1000
# Set total response time if supported
if self.supports_response_time:
self.result._response_ms = total_response_time_ms
# Calculate LiteLLM overhead
llm_api_duration_ms = logging_obj.model_call_details.get("llm_api_duration_ms")
if llm_api_duration_ms is not None:
overhead_ms = round(total_response_time_ms - llm_api_duration_ms, 4)
self._update_hidden_params(
{
"litellm_overhead_time_ms": overhead_ms,
"_response_ms": total_response_time_ms,
}
)
def apply(self) -> None:
"""Apply metadata to the response object"""
if hasattr(self.result, "_hidden_params"):
self.result._hidden_params = self._hidden_params
def update_response_metadata(
result: Any,
logging_obj: LiteLLMLoggingObject,
model: Optional[str],
kwargs: dict,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> None:
"""
Updates response metadata including hidden params and timing metrics
"""
if result is None:
return
metadata = ResponseMetadata(result)
metadata.set_hidden_params(logging_obj, model, kwargs)
metadata.set_timing_metrics(start_time, end_time, logging_obj)
metadata.apply()
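
Note on the calculation above: litellm_overhead_time_ms is simply the wall-clock time measured around the whole LiteLLM call minus the time the instrumented HTTP call spent waiting on the provider. A minimal sketch with made-up numbers (values are illustrative, not from this change):

total_response_time_ms = 2350.0   # wall-clock time around the full litellm call
llm_api_duration_ms = 2100.0      # recorded by track_llm_api_timing around the provider HTTP call
overhead_ms = round(total_response_time_ms - llm_api_duration_ms, 4)
print(overhead_ms)  # 250.0 -> surfaced as _hidden_params["litellm_overhead_time_ms"]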

View file

@ -1,3 +1,5 @@
import asyncio
import functools
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
@ -10,10 +12,14 @@ from litellm.types.utils import (
if TYPE_CHECKING:
from litellm import ModelResponse as _ModelResponse
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObject,
)
LiteLLMModelResponse = _ModelResponse
else:
LiteLLMModelResponse = Any
LiteLLMLoggingObject = Any
import litellm
@ -91,3 +97,64 @@ def _assemble_complete_response_from_streaming_chunks(
else:
streaming_chunks.append(result)
return complete_streaming_response
def _set_duration_in_model_call_details(
logging_obj: Any, # we're not guaranteed this will be `LiteLLMLoggingObject`
start_time: datetime,
end_time: datetime,
):
"""Helper to set duration in model_call_details, with error handling"""
try:
duration_ms = (end_time - start_time).total_seconds() * 1000
if logging_obj and hasattr(logging_obj, "model_call_details"):
logging_obj.model_call_details["llm_api_duration_ms"] = duration_ms
else:
verbose_logger.warning(
"`logging_obj` not found - unable to track `llm_api_duration_ms"
)
except Exception as e:
verbose_logger.warning(f"Error setting `llm_api_duration_ms`: {str(e)}")
def track_llm_api_timing():
"""
Decorator to track LLM API call timing for both sync and async functions.
The logging_obj is expected to be passed as an argument to the decorated function.
"""
def decorator(func):
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
start_time = datetime.now()
try:
result = await func(*args, **kwargs)
return result
finally:
end_time = datetime.now()
_set_duration_in_model_call_details(
logging_obj=kwargs.get("logging_obj", None),
start_time=start_time,
end_time=end_time,
)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
start_time = datetime.now()
try:
result = func(*args, **kwargs)
return result
finally:
end_time = datetime.now()
_set_duration_in_model_call_details(
logging_obj=kwargs.get("logging_obj", None),
start_time=start_time,
end_time=end_time,
)
# Check if the function is async or sync
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
return decorator
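
A minimal sketch of how a provider helper opts into this timing; call_provider and my_logging_obj are hypothetical names. The duration is only recorded when logging_obj is passed as a keyword argument, since the wrapper reads it from kwargs:

from litellm.litellm_core_utils.logging_utils import track_llm_api_timing

@track_llm_api_timing()
async def call_provider(url: str, data: dict, logging_obj=None):
    ...  # issue the raw provider request here

# await call_provider(url, data, logging_obj=my_logging_obj)
# afterwards, my_logging_obj.model_call_details["llm_api_duration_ms"] holds the call duration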

View file

@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -26,7 +27,7 @@ def make_sync_call(
data: str,
model: str,
messages: list,
logging_obj,
logging_obj: LiteLLMLoggingObject,
json_mode: Optional[bool] = False,
fake_stream: bool = False,
):
@ -38,6 +39,7 @@ def make_sync_call(
headers=headers,
data=data,
stream=not fake_stream,
logging_obj=logging_obj,
)
if response.status_code != 200:
@ -171,7 +173,7 @@ class BedrockConverseLLM(BaseAWSLLM):
print_verbose: Callable,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
logging_obj: LiteLLMLoggingObject,
stream,
optional_params: dict,
litellm_params: dict,
@ -223,7 +225,9 @@ class BedrockConverseLLM(BaseAWSLLM):
client = client # type: ignore
try:
response = await client.post(url=api_base, headers=headers, data=data) # type: ignore
response = await client.post(
url=api_base, headers=headers, data=data, logging_obj=logging_obj
) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -254,7 +258,7 @@ class BedrockConverseLLM(BaseAWSLLM):
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
logging_obj: LiteLLMLoggingObject,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
@ -458,7 +462,12 @@ class BedrockConverseLLM(BaseAWSLLM):
### COMPLETION
try:
response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore
response = client.post(
url=proxy_endpoint_url,
headers=prepped.headers,
data=data,
logging_obj=logging_obj,
) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code

View file

@ -28,6 +28,7 @@ from litellm import verbose_logger
from litellm.caching.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt,
construct_tool_use_system_prompt,
@ -171,7 +172,7 @@ async def make_call(
data: str,
model: str,
messages: list,
logging_obj,
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
):
@ -186,6 +187,7 @@ async def make_call(
headers=headers,
data=data,
stream=not fake_stream,
logging_obj=logging_obj,
)
if response.status_code != 200:
@ -577,7 +579,7 @@ class BedrockLLM(BaseAWSLLM):
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
logging_obj: Logging,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
@ -890,6 +892,7 @@ class BedrockLLM(BaseAWSLLM):
headers=prepped.headers, # type: ignore
data=data,
stream=stream,
logging_obj=logging_obj,
)
if response.status_code != 200:
@ -917,7 +920,12 @@ class BedrockLLM(BaseAWSLLM):
return streaming_response
try:
response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore
response = self.client.post(
url=proxy_endpoint_url,
headers=dict(prepped.headers),
data=data,
logging_obj=logging_obj,
)
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -949,7 +957,7 @@ class BedrockLLM(BaseAWSLLM):
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
logging_obj: Logging,
stream,
optional_params: dict,
litellm_params=None,
@ -968,7 +976,13 @@ class BedrockLLM(BaseAWSLLM):
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(
api_base,
headers=headers,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -990,6 +1004,7 @@ class BedrockLLM(BaseAWSLLM):
encoding=encoding,
)
@track_llm_api_timing() # for streaming, we need to instrument the function calling the wrapper
async def async_streaming(
self,
model: str,
@ -1000,7 +1015,7 @@ class BedrockLLM(BaseAWSLLM):
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
logging_obj: Logging,
stream,
optional_params: dict,
litellm_params=None,

View file

@ -6,12 +6,17 @@ import httpx
from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
import litellm
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.types.llms.custom_http import *
if TYPE_CHECKING:
from litellm import LlmProviders
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObject,
)
else:
LlmProviders = Any
LiteLLMLoggingObject = Any
try:
from litellm._version import version
@ -156,6 +161,7 @@ class AsyncHTTPHandler:
)
return response
@track_llm_api_timing()
async def post(
self,
url: str,
@ -165,6 +171,7 @@ class AsyncHTTPHandler:
headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
stream: bool = False,
logging_obj: Optional[LiteLLMLoggingObject] = None,
):
try:
if timeout is None:
@ -494,6 +501,7 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
files: Optional[dict] = None,
content: Any = None,
logging_obj: Optional[LiteLLMLoggingObject] = None,
):
try:
if timeout is not None:
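
Threading logging_obj into the shared HTTP handlers is what connects the decorator to real requests: post() is now wrapped by track_llm_api_timing, so any caller that forwards its Logging object gets llm_api_duration_ms recorded. A minimal sketch, assuming you already hold a LiteLLM Logging object (variable names are hypothetical):

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def send_request(api_base: str, headers: dict, payload: str, logging_obj):
    client = AsyncHTTPHandler()
    # passing logging_obj lets the decorated post() write
    # logging_obj.model_call_details["llm_api_duration_ms"] for this call
    return await client.post(
        url=api_base, headers=headers, data=payload, logging_obj=logging_obj
    )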

View file

@ -27,6 +27,7 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
@ -380,11 +381,13 @@ class OpenAIChatCompletion(BaseLLM):
else:
return client
@track_llm_api_timing()
async def make_openai_chat_completion_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
) -> Tuple[dict, BaseModel]:
"""
Helper to:
@ -414,11 +417,13 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
raise e
@track_llm_api_timing()
def make_sync_openai_chat_completion_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
) -> Tuple[dict, BaseModel]:
"""
Helper to:
@ -630,6 +635,7 @@ class OpenAIChatCompletion(BaseLLM):
openai_client=openai_client,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
)
@ -762,7 +768,10 @@ class OpenAIChatCompletion(BaseLLM):
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
stringified_response = response.model_dump()
@ -852,6 +861,7 @@ class OpenAIChatCompletion(BaseLLM):
openai_client=openai_client,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
logging_obj.model_call_details["response_headers"] = headers
@ -910,7 +920,10 @@ class OpenAIChatCompletion(BaseLLM):
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
@ -965,11 +978,13 @@ class OpenAIChatCompletion(BaseLLM):
)
# Embedding
@track_llm_api_timing()
async def make_openai_embedding_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
):
"""
Helper to:
@ -986,11 +1001,13 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
raise e
@track_llm_api_timing()
def make_sync_openai_embedding_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
):
"""
Helper to:
@ -1030,7 +1047,10 @@ class OpenAIChatCompletion(BaseLLM):
client=client,
)
headers, response = await self.make_openai_embedding_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
@ -1128,7 +1148,10 @@ class OpenAIChatCompletion(BaseLLM):
## embedding CALL
headers: Optional[Dict] = None
headers, sync_embedding_response = self.make_sync_openai_embedding_request(
openai_client=openai_client, data=data, timeout=timeout
openai_client=openai_client,
data=data,
timeout=timeout,
logging_obj=logging_obj,
) # type: ignore
## LOGGING

View file

@ -733,11 +733,13 @@ def get_custom_headers(
version: Optional[str] = None,
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
hidden_params: Optional[dict] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
**kwargs,
) -> dict:
exclude_values = {"", None}
hidden_params = hidden_params or {}
headers = {
"x-litellm-call-id": call_id,
"x-litellm-model-id": model_id,
@ -750,6 +752,10 @@ def get_custom_headers(
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
"x-litellm-key-spend": str(user_api_key_dict.spend),
"x-litellm-response-duration-ms": str(hidden_params.get("_response_ms", None)),
"x-litellm-overhead-duration-ms": str(
hidden_params.get("litellm_overhead_time_ms", None)
),
"x-litellm-fastest_response_batch_completion": (
str(fastest_response_batch_completion)
if fastest_response_batch_completion is not None
@ -3491,6 +3497,7 @@ async def chat_completion( # noqa: PLR0915
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
selected_data_generator = select_data_generator(
@ -3526,6 +3533,7 @@ async def chat_completion( # noqa: PLR0915
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
)
@ -3719,6 +3727,7 @@ async def completion( # noqa: PLR0915
api_base=api_base,
version=version,
response_cost=response_cost,
hidden_params=hidden_params,
request_data=data,
)
selected_data_generator = select_data_generator(
@ -3747,6 +3756,7 @@ async def completion( # noqa: PLR0915
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
await check_response_size_is_safe(response=response)
@ -3977,6 +3987,7 @@ async def embeddings( # noqa: PLR0915
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
)
@ -4103,6 +4114,7 @@ async def image_generation(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
hidden_params=hidden_params,
)
)
@ -4223,6 +4235,7 @@ async def audio_speech(
fastest_response_batch_completion=None,
call_id=litellm_call_id,
request_data=data,
hidden_params=hidden_params,
)
select_data_generator(
@ -4362,6 +4375,7 @@ async def audio_transcriptions(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
)
@ -4510,6 +4524,7 @@ async def get_assistants(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -4607,6 +4622,7 @@ async def create_assistant(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -4703,6 +4719,7 @@ async def delete_assistant(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -4799,6 +4816,7 @@ async def create_threads(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -4894,6 +4912,7 @@ async def get_thread(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -4992,6 +5011,7 @@ async def add_messages(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -5086,6 +5106,7 @@ async def get_messages(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -5194,6 +5215,7 @@ async def run_thread(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -5316,6 +5338,7 @@ async def moderations(
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
hidden_params=hidden_params,
)
)
@ -5488,6 +5511,7 @@ async def anthropic_response( # noqa: PLR0915
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
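
On the proxy side the two hidden params become response headers. A minimal sketch of reading them from a running LiteLLM proxy; the base URL, key, and model below are placeholders:

import httpx

resp = httpx.post(
    "http://localhost:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
)
print(resp.headers.get("x-litellm-response-duration-ms"))   # total time measured by LiteLLM
print(resp.headers.get("x-litellm-overhead-duration-ms"))   # portion spent outside the LLM API call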

View file

@ -93,6 +93,9 @@ from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
get_response_headers,
)
from litellm.litellm_core_utils.llm_response_utils.response_metadata import (
ResponseMetadata,
)
from litellm.litellm_core_utils.redact_messages import (
LiteLLMLoggingObject,
redact_message_input_output_from_logging,
@ -929,6 +932,15 @@ def client(original_function): # noqa: PLR0915
chunks, messages=kwargs.get("messages", None)
)
else:
# RETURN RESULT
update_response_metadata(
result=result,
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
return result
elif "acompletion" in kwargs and kwargs["acompletion"] is True:
return result
@ -966,25 +978,14 @@ def client(original_function): # noqa: PLR0915
end_time,
)
# RETURN RESULT
if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model or "",
optional_params=getattr(logging_obj, "optional_params", {}),
)
result._hidden_params["response_cost"] = (
logging_obj._response_cost_calculator(result=result)
)
result._hidden_params["additional_headers"] = process_response_headers(
result._hidden_params.get("additional_headers") or {}
) # GUARANTEE OPENAI HEADERS IN RESPONSE
if result is not None:
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
update_response_metadata(
result=result,
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
return result
except Exception as e:
call_type = original_function.__name__
@ -1116,39 +1117,17 @@ def client(original_function): # noqa: PLR0915
chunks, messages=kwargs.get("messages", None)
)
else:
update_response_metadata(
result=result,
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
return result
elif call_type == CallTypes.arealtime.value:
return result
# ADD HIDDEN PARAMS - additional call metadata
if hasattr(result, "_hidden_params"):
result._hidden_params["litellm_call_id"] = getattr(
logging_obj, "litellm_call_id", None
)
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model or "",
optional_params=kwargs,
)
result._hidden_params["response_cost"] = (
logging_obj._response_cost_calculator(result=result)
)
result._hidden_params["additional_headers"] = process_response_headers(
result._hidden_params.get("additional_headers") or {}
) # GUARANTEE OPENAI HEADERS IN RESPONSE
if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
):
setattr(
result,
"_response_ms",
(end_time - start_time).total_seconds() * 1000,
) # return response latency in ms like openai
### POST-CALL RULES ###
post_call_processing(
original_response=result, model=model, optional_params=kwargs
@ -1190,6 +1169,15 @@ def client(original_function): # noqa: PLR0915
end_time=end_time,
)
update_response_metadata(
result=result,
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
return result
except Exception as e:
traceback_exception = traceback.format_exc()
@ -1293,6 +1281,31 @@ def _is_async_request(
return False
def update_response_metadata(
result: Any,
logging_obj: LiteLLMLoggingObject,
model: Optional[str],
kwargs: dict,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> None:
"""
Updates response metadata, adds the following:
- response._hidden_params
- response._hidden_params["litellm_overhead_time_ms"]
- response._response_ms (response latency in ms)
"""
if result is None:
return
metadata = ResponseMetadata(result)
metadata.set_hidden_params(logging_obj=logging_obj, model=model, kwargs=kwargs)
metadata.set_timing_metrics(
start_time=start_time, end_time=end_time, logging_obj=logging_obj
)
metadata.apply()
def _select_tokenizer(
model: str, custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
):
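
At the SDK level the same metrics land on the response object's _hidden_params. A minimal sketch (model and prompt are placeholders; litellm_overhead_time_ms is only present when the underlying HTTP call recorded llm_api_duration_ms):

import litellm

response = litellm.completion(
    model="openai/gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response._hidden_params.get("_response_ms"))              # total response time in ms
print(response._hidden_params.get("litellm_overhead_time_ms"))  # LiteLLM overhead in ms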

View file

@ -0,0 +1,116 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"bedrock/mistral.mistral-7b-instruct-v0:2",
"openai/gpt-4o",
"openai/self_hosted",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
],
)
async def test_litellm_overhead(model):
litellm._turn_on_debug()
start_time = datetime.now()
if model == "openai/self_hosted":
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
)
else:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
)
end_time = datetime.now()
total_time_ms = (end_time - start_time).total_seconds() * 1000
print(response)
print(response._hidden_params)
litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
# calculate percent of overhead caused by litellm
overhead_percent = litellm_overhead_ms * 100 / total_time_ms
print("##########################\n")
print("total_time_ms", total_time_ms)
print("response litellm_overhead_ms", litellm_overhead_ms)
print("litellm overhead_percent {}%".format(overhead_percent))
print("##########################\n")
assert litellm_overhead_ms > 0
assert litellm_overhead_ms < 1000
# latency overhead should be less than total request time
assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000
# latency overhead should be under 40% of total request time
assert overhead_percent < 40
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"bedrock/mistral.mistral-7b-instruct-v0:2",
"openai/gpt-4o",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"openai/self_hosted",
],
)
async def test_litellm_overhead_stream(model):
litellm._turn_on_debug()
start_time = datetime.now()
if model == "openai/self_hosted":
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
stream=True,
)
else:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello, world!"}],
stream=True,
)
async for chunk in response:
print()
end_time = datetime.now()
total_time_ms = (end_time - start_time).total_seconds() * 1000
print(response)
print(response._hidden_params)
litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
# calculate percent of overhead caused by litellm
overhead_percent = litellm_overhead_ms * 100 / total_time_ms
print("##########################\n")
print("total_time_ms", total_time_ms)
print("response litellm_overhead_ms", litellm_overhead_ms)
print("litellm overhead_percent {}%".format(overhead_percent))
print("##########################\n")
assert litellm_overhead_ms > 0
assert litellm_overhead_ms < 1000
# latency overhead should be less than total request time
assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000
# latency overhead should be under 40% of total request time
assert overhead_percent < 40
pass