Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)

(Feat) Add x-litellm-overhead-duration-ms and x-litellm-response-duration-ms headers in responses from LiteLLM (#7899)

* add track_llm_api_timing
* add track_llm_api_timing
* test_litellm_overhead
* use ResponseMetadata class for setting hidden params and response overhead
* instrument http handler
* fix track_llm_api_timing
* track_llm_api_timing
* emit response overhead on hidden params
* fix resp metadata
* fix make_sync_openai_embedding_request
* test_aaaaatext_completion_endpoint fixes
* _get_value_from_hidden_params
* set_hidden_params
* test_litellm_overhead
* test_litellm_overhead
* test_litellm_overhead
* fix import
* test_litellm_overhead_stream
* add LiteLLMLoggingObject
* use diff folder for testing
* use diff folder for overhead testing
* test litellm overhead
* use typing
* clear typing
* test_litellm_overhead
* fix async_streaming
* update_response_metadata
* move test file
* apply metadata to the response object

Parent: 63d7d04232
Commit: b6f2e659b9

17 changed files with 464 additions and 73 deletions
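Illustrative note (not part of the commit): with this change, LiteLLM measures how long the underlying provider HTTP call takes (llm_api_duration_ms), stores the total request latency and LiteLLM's own overhead on the response's hidden params, and the proxy surfaces them as the two new response headers. A minimal sketch of reading the new fields via the Python SDK, assuming a supported model and valid provider credentials:

    import litellm

    # Any instrumented, non-streaming call will carry the new timing fields.
    response = litellm.completion(
        model="openai/gpt-4o",  # assumed model; any supported provider works
        messages=[{"role": "user", "content": "Hello, world!"}],
    )

    # Total wall-clock latency in ms, and LiteLLM overhead in ms
    # (total latency minus the measured provider call duration).
    print(response._hidden_params["_response_ms"])
    print(response._hidden_params["litellm_overhead_time_ms"])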
@@ -669,7 +669,7 @@ jobs:
paths:
- batches_coverage.xml
- batches_coverage
secret_manager_testing:
litellm_utils_testing:
docker:
- image: cimg/python:3.11
auth:
@@ -697,13 +697,13 @@ jobs:
command: |
pwd
ls
python -m pytest -vv tests/secret_manager_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
python -m pytest -vv tests/litellm_utils_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml secret_manager_coverage.xml
mv .coverage secret_manager_coverage
mv coverage.xml litellm_utils_coverage.xml
mv .coverage litellm_utils_coverage

# Store test results
- store_test_results:
@@ -711,8 +711,8 @@ jobs:
- persist_to_workspace:
root: .
paths:
- secret_manager_coverage.xml
- secret_manager_coverage
- litellm_utils_coverage.xml
- litellm_utils_coverage

pass_through_unit_testing:
docker:
@@ -2029,7 +2029,7 @@ workflows:
only:
- main
- /litellm_.*/
- secret_manager_testing:
- litellm_utils_testing:
filters:
branches:
only:
@@ -2057,7 +2057,7 @@
requires:
- llm_translation_testing
- batches_testing
- secret_manager_testing
- litellm_utils_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing
@@ -2113,7 +2113,7 @@
- test_bad_database_url
- llm_translation_testing
- batches_testing
- secret_manager_testing
- litellm_utils_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing
@@ -0,0 +1,116 @@
import datetime
from typing import Any, Optional, Union

from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.logging_utils import LiteLLMLoggingObject
from litellm.types.utils import (
    EmbeddingResponse,
    HiddenParams,
    ModelResponse,
    TranscriptionResponse,
)


class ResponseMetadata:
    """
    Handles setting and managing `_hidden_params`, `response_time_ms`, and `litellm_overhead_time_ms` for LiteLLM responses
    """

    def __init__(self, result: Any):
        self.result = result
        self._hidden_params: Union[HiddenParams, dict] = (
            getattr(result, "_hidden_params", {}) or {}
        )

    @property
    def supports_response_time(self) -> bool:
        """Check if response type supports timing metrics"""
        return (
            isinstance(self.result, ModelResponse)
            or isinstance(self.result, EmbeddingResponse)
            or isinstance(self.result, TranscriptionResponse)
        )

    def set_hidden_params(
        self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
    ) -> None:
        """Set hidden parameters on the response"""
        new_params = {
            "litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
            "model_id": kwargs.get("model_info", {}).get("id", None),
            "api_base": get_api_base(model=model or "", optional_params=kwargs),
            "response_cost": logging_obj._response_cost_calculator(result=self.result),
            "additional_headers": process_response_headers(
                self._get_value_from_hidden_params("additional_headers") or {}
            ),
        }
        self._update_hidden_params(new_params)

    def _update_hidden_params(self, new_params: dict) -> None:
        """
        Update hidden params - handles when self._hidden_params is a dict or HiddenParams object
        """
        # Handle both dict and HiddenParams cases
        if isinstance(self._hidden_params, dict):
            self._hidden_params.update(new_params)
        elif isinstance(self._hidden_params, HiddenParams):
            # For HiddenParams object, set attributes individually
            for key, value in new_params.items():
                setattr(self._hidden_params, key, value)

    def _get_value_from_hidden_params(self, key: str) -> Optional[Any]:
        """Get value from hidden params - handles when self._hidden_params is a dict or HiddenParams object"""
        if isinstance(self._hidden_params, dict):
            return self._hidden_params.get(key, None)
        elif isinstance(self._hidden_params, HiddenParams):
            return getattr(self._hidden_params, key, None)

    def set_timing_metrics(
        self,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        logging_obj: LiteLLMLoggingObject,
    ) -> None:
        """Set response timing metrics"""
        total_response_time_ms = (end_time - start_time).total_seconds() * 1000

        # Set total response time if supported
        if self.supports_response_time:
            self.result._response_ms = total_response_time_ms

        # Calculate LiteLLM overhead
        llm_api_duration_ms = logging_obj.model_call_details.get("llm_api_duration_ms")
        if llm_api_duration_ms is not None:
            overhead_ms = round(total_response_time_ms - llm_api_duration_ms, 4)
            self._update_hidden_params(
                {
                    "litellm_overhead_time_ms": overhead_ms,
                    "_response_ms": total_response_time_ms,
                }
            )

    def apply(self) -> None:
        """Apply metadata to the response object"""
        if hasattr(self.result, "_hidden_params"):
            self.result._hidden_params = self._hidden_params


def update_response_metadata(
    result: Any,
    logging_obj: LiteLLMLoggingObject,
    model: Optional[str],
    kwargs: dict,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
) -> None:
    """
    Updates response metadata including hidden params and timing metrics
    """
    if result is None:
        return

    metadata = ResponseMetadata(result)
    metadata.set_hidden_params(logging_obj, model, kwargs)
    metadata.set_timing_metrics(start_time, end_time, logging_obj)
    metadata.apply()
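Illustrative note (not part of the commit): the overhead figure computed above is simply total wall-clock time minus the time spent inside the provider HTTP call, which the timing decorator in the next file records as llm_api_duration_ms. A standalone sketch of the same arithmetic with made-up numbers:

    from datetime import datetime, timedelta

    start_time = datetime.now()
    end_time = start_time + timedelta(milliseconds=1250)  # pretend total request time
    llm_api_duration_ms = 1180.0                          # pretend provider call time

    total_response_time_ms = (end_time - start_time).total_seconds() * 1000
    overhead_ms = round(total_response_time_ms - llm_api_duration_ms, 4)
    print(overhead_ms)  # roughly 70 ms attributed to LiteLLM overhead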
@@ -1,3 +1,5 @@
import asyncio
import functools
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union

@@ -10,10 +12,14 @@ from litellm.types.utils import (

if TYPE_CHECKING:
    from litellm import ModelResponse as _ModelResponse
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LiteLLMLoggingObject,
    )

    LiteLLMModelResponse = _ModelResponse
else:
    LiteLLMModelResponse = Any
    LiteLLMLoggingObject = Any


import litellm
@@ -91,3 +97,64 @@ def _assemble_complete_response_from_streaming_chunks(
    else:
        streaming_chunks.append(result)
    return complete_streaming_response


def _set_duration_in_model_call_details(
    logging_obj: Any,  # we're not guaranteed this will be `LiteLLMLoggingObject`
    start_time: datetime,
    end_time: datetime,
):
    """Helper to set duration in model_call_details, with error handling"""
    try:
        duration_ms = (end_time - start_time).total_seconds() * 1000
        if logging_obj and hasattr(logging_obj, "model_call_details"):
            logging_obj.model_call_details["llm_api_duration_ms"] = duration_ms
        else:
            verbose_logger.warning(
                "`logging_obj` not found - unable to track `llm_api_duration_ms"
            )
    except Exception as e:
        verbose_logger.warning(f"Error setting `llm_api_duration_ms`: {str(e)}")


def track_llm_api_timing():
    """
    Decorator to track LLM API call timing for both sync and async functions.
    The logging_obj is expected to be passed as an argument to the decorated function.
    """

    def decorator(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = datetime.now()
            try:
                result = await func(*args, **kwargs)
                return result
            finally:
                end_time = datetime.now()
                _set_duration_in_model_call_details(
                    logging_obj=kwargs.get("logging_obj", None),
                    start_time=start_time,
                    end_time=end_time,
                )

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = datetime.now()
            try:
                result = func(*args, **kwargs)
                return result
            finally:
                end_time = datetime.now()
                _set_duration_in_model_call_details(
                    logging_obj=kwargs.get("logging_obj", None),
                    start_time=start_time,
                    end_time=end_time,
                )

        # Check if the function is async or sync
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
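Illustrative note (not part of the commit): track_llm_api_timing only requires that the decorated callable receive a logging_obj keyword argument exposing a model_call_details dict; the elapsed time is then stored under llm_api_duration_ms. A minimal sketch using a hypothetical stand-in logging object:

    import asyncio
    from types import SimpleNamespace

    from litellm.litellm_core_utils.logging_utils import track_llm_api_timing

    @track_llm_api_timing()
    async def fake_llm_call(*, logging_obj):
        await asyncio.sleep(0.1)  # stand-in for the real provider HTTP call
        return "ok"

    logging_obj = SimpleNamespace(model_call_details={})
    asyncio.run(fake_llm_call(logging_obj=logging_obj))
    print(logging_obj.model_call_details["llm_api_duration_ms"])  # roughly 100 ms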
@@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
@@ -26,7 +27,7 @@ def make_sync_call(
    data: str,
    model: str,
    messages: list,
    logging_obj,
    logging_obj: LiteLLMLoggingObject,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
):
@@ -38,6 +39,7 @@ def make_sync_call(
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
@@ -171,7 +173,7 @@ class BedrockConverseLLM(BaseAWSLLM):
        print_verbose: Callable,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        logging_obj: LiteLLMLoggingObject,
        stream,
        optional_params: dict,
        litellm_params: dict,
@@ -223,7 +225,9 @@ class BedrockConverseLLM(BaseAWSLLM):
            client = client  # type: ignore

        try:
            response = await client.post(url=api_base, headers=headers, data=data)  # type: ignore
            response = await client.post(
                url=api_base, headers=headers, data=data, logging_obj=logging_obj
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
@@ -254,7 +258,7 @@ class BedrockConverseLLM(BaseAWSLLM):
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
@@ -458,7 +462,12 @@ class BedrockConverseLLM(BaseAWSLLM):
        ### COMPLETION

        try:
            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
            response = client.post(
                url=proxy_endpoint_url,
                headers=prepped.headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
@@ -28,6 +28,7 @@ from litellm import verbose_logger
from litellm.caching.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
    cohere_message_pt,
    construct_tool_use_system_prompt,
@@ -171,7 +172,7 @@ async def make_call(
    data: str,
    model: str,
    messages: list,
    logging_obj,
    logging_obj: Logging,
    fake_stream: bool = False,
    json_mode: Optional[bool] = False,
):
@@ -186,6 +187,7 @@ async def make_call(
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
@@ -577,7 +579,7 @@ class BedrockLLM(BaseAWSLLM):
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        logging_obj: Logging,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
@@ -890,6 +892,7 @@ class BedrockLLM(BaseAWSLLM):
                headers=prepped.headers,  # type: ignore
                data=data,
                stream=stream,
                logging_obj=logging_obj,
            )

            if response.status_code != 200:
@@ -917,7 +920,12 @@ class BedrockLLM(BaseAWSLLM):
            return streaming_response

        try:
            response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
            response = self.client.post(
                url=proxy_endpoint_url,
                headers=dict(prepped.headers),
                data=data,
                logging_obj=logging_obj,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
@@ -949,7 +957,7 @@ class BedrockLLM(BaseAWSLLM):
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        logging_obj: Logging,
        stream,
        optional_params: dict,
        litellm_params=None,
@@ -968,7 +976,13 @@ class BedrockLLM(BaseAWSLLM):
            client = client  # type: ignore

        try:
            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
            response = await client.post(
                api_base,
                headers=headers,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
@@ -990,6 +1004,7 @@ class BedrockLLM(BaseAWSLLM):
            encoding=encoding,
        )

    @track_llm_api_timing()  # for streaming, we need to instrument the function calling the wrapper
    async def async_streaming(
        self,
        model: str,
@@ -1000,7 +1015,7 @@ class BedrockLLM(BaseAWSLLM):
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        logging_obj: Logging,
        stream,
        optional_params: dict,
        litellm_params=None,
@@ -6,12 +6,17 @@ import httpx
from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport

import litellm
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.types.llms.custom_http import *

if TYPE_CHECKING:
    from litellm import LlmProviders
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LiteLLMLoggingObject,
    )
else:
    LlmProviders = Any
    LiteLLMLoggingObject = Any

try:
    from litellm._version import version
@@ -156,6 +161,7 @@ class AsyncHTTPHandler:
            )
            return response

    @track_llm_api_timing()
    async def post(
        self,
        url: str,
@@ -165,6 +171,7 @@ class AsyncHTTPHandler:
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        stream: bool = False,
        logging_obj: Optional[LiteLLMLoggingObject] = None,
    ):
        try:
            if timeout is None:
@@ -494,6 +501,7 @@ class HTTPHandler:
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        files: Optional[dict] = None,
        content: Any = None,
        logging_obj: Optional[LiteLLMLoggingObject] = None,
    ):
        try:
            if timeout is not None:
@@ -27,6 +27,7 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
@@ -380,11 +381,13 @@ class OpenAIChatCompletion(BaseLLM):
        else:
            return client

    @track_llm_api_timing()
    async def make_openai_chat_completion_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ) -> Tuple[dict, BaseModel]:
        """
        Helper to:
@@ -414,11 +417,13 @@ class OpenAIChatCompletion(BaseLLM):
        except Exception as e:
            raise e

    @track_llm_api_timing()
    def make_sync_openai_chat_completion_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ) -> Tuple[dict, BaseModel]:
        """
        Helper to:
@@ -630,6 +635,7 @@ class OpenAIChatCompletion(BaseLLM):
                    openai_client=openai_client,
                    data=data,
                    timeout=timeout,
                    logging_obj=logging_obj,
                )
            )

@@ -762,7 +768,10 @@ class OpenAIChatCompletion(BaseLLM):
            )

            headers, response = await self.make_openai_chat_completion_request(
                openai_aclient=openai_aclient, data=data, timeout=timeout
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )
            stringified_response = response.model_dump()

@@ -852,6 +861,7 @@ class OpenAIChatCompletion(BaseLLM):
                openai_client=openai_client,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )

            logging_obj.model_call_details["response_headers"] = headers
@@ -910,7 +920,10 @@ class OpenAIChatCompletion(BaseLLM):
            )

            headers, response = await self.make_openai_chat_completion_request(
                openai_aclient=openai_aclient, data=data, timeout=timeout
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )
            logging_obj.model_call_details["response_headers"] = headers
            streamwrapper = CustomStreamWrapper(
@@ -965,11 +978,13 @@ class OpenAIChatCompletion(BaseLLM):
            )

    # Embedding
    @track_llm_api_timing()
    async def make_openai_embedding_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ):
        """
        Helper to:
@@ -986,11 +1001,13 @@ class OpenAIChatCompletion(BaseLLM):
        except Exception as e:
            raise e

    @track_llm_api_timing()
    def make_sync_openai_embedding_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ):
        """
        Helper to:
@@ -1030,7 +1047,10 @@ class OpenAIChatCompletion(BaseLLM):
                client=client,
            )
            headers, response = await self.make_openai_embedding_request(
                openai_aclient=openai_aclient, data=data, timeout=timeout
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )
            logging_obj.model_call_details["response_headers"] = headers
            stringified_response = response.model_dump()
@@ -1128,7 +1148,10 @@ class OpenAIChatCompletion(BaseLLM):
            ## embedding CALL
            headers: Optional[Dict] = None
            headers, sync_embedding_response = self.make_sync_openai_embedding_request(
                openai_client=openai_client, data=data, timeout=timeout
                openai_client=openai_client,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )  # type: ignore

            ## LOGGING
@@ -733,11 +733,13 @@ def get_custom_headers(
    version: Optional[str] = None,
    model_region: Optional[str] = None,
    response_cost: Optional[Union[float, str]] = None,
    hidden_params: Optional[dict] = None,
    fastest_response_batch_completion: Optional[bool] = None,
    request_data: Optional[dict] = {},
    **kwargs,
) -> dict:
    exclude_values = {"", None}
    hidden_params = hidden_params or {}
    headers = {
        "x-litellm-call-id": call_id,
        "x-litellm-model-id": model_id,
@@ -750,6 +752,10 @@ def get_custom_headers(
        "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
        "x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
        "x-litellm-key-spend": str(user_api_key_dict.spend),
        "x-litellm-response-duration-ms": str(hidden_params.get("_response_ms", None)),
        "x-litellm-overhead-duration-ms": str(
            hidden_params.get("litellm_overhead_time_ms", None)
        ),
        "x-litellm-fastest_response_batch_completion": (
            str(fastest_response_batch_completion)
            if fastest_response_batch_completion is not None
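Illustrative note (not part of the commit): once get_custom_headers maps the hidden params into headers as above, any HTTP client can read the new values from a proxy response. A sketch against a locally running proxy; the base URL and virtual key are assumptions:

    import httpx

    resp = httpx.post(
        "http://localhost:4000/chat/completions",     # assumed local proxy address
        headers={"Authorization": "Bearer sk-1234"},  # assumed virtual key
        json={
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    print(resp.headers.get("x-litellm-response-duration-ms"))
    print(resp.headers.get("x-litellm-overhead-duration-ms"))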
@@ -3491,6 +3497,7 @@ async def chat_completion( # noqa: PLR0915
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
            selected_data_generator = select_data_generator(
@@ -3526,6 +3533,7 @@ async def chat_completion( # noqa: PLR0915
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
        )
@@ -3719,6 +3727,7 @@ async def completion( # noqa: PLR0915
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                hidden_params=hidden_params,
                request_data=data,
            )
            selected_data_generator = select_data_generator(
@@ -3747,6 +3756,7 @@ async def completion( # noqa: PLR0915
                version=version,
                response_cost=response_cost,
                request_data=data,
                hidden_params=hidden_params,
            )
        )
        await check_response_size_is_safe(response=response)
@@ -3977,6 +3987,7 @@ async def embeddings( # noqa: PLR0915
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
        )
@@ -4103,6 +4114,7 @@ async def image_generation(
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4223,6 +4235,7 @@ async def audio_speech(
                fastest_response_batch_completion=None,
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
            )

        select_data_generator(
@@ -4362,6 +4375,7 @@ async def audio_transcriptions(
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
                **additional_headers,
            )
        )
@@ -4510,6 +4524,7 @@ async def get_assistants(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4607,6 +4622,7 @@ async def create_assistant(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4703,6 +4719,7 @@ async def delete_assistant(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4799,6 +4816,7 @@ async def create_threads(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4894,6 +4912,7 @@ async def get_thread(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -4992,6 +5011,7 @@ async def add_messages(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -5086,6 +5106,7 @@ async def get_messages(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -5194,6 +5215,7 @@ async def run_thread(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -5316,6 +5338,7 @@ async def moderations(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                hidden_params=hidden_params,
            )
        )

@@ -5488,6 +5511,7 @@ async def anthropic_response( # noqa: PLR0915
                version=version,
                response_cost=response_cost,
                request_data=data,
                hidden_params=hidden_params,
            )
        )
litellm/utils.py
@@ -93,6 +93,9 @@ from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
    get_response_headers,
)
from litellm.litellm_core_utils.llm_response_utils.response_metadata import (
    ResponseMetadata,
)
from litellm.litellm_core_utils.redact_messages import (
    LiteLLMLoggingObject,
    redact_message_input_output_from_logging,
@@ -929,6 +932,15 @@ def client(original_function): # noqa: PLR0915
                        chunks, messages=kwargs.get("messages", None)
                    )
                else:
                    # RETURN RESULT
                    update_response_metadata(
                        result=result,
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        start_time=start_time,
                        end_time=end_time,
                    )
                    return result
            elif "acompletion" in kwargs and kwargs["acompletion"] is True:
                return result
@@ -966,25 +978,14 @@ def client(original_function): # noqa: PLR0915
                end_time,
            )
            # RETURN RESULT
            if hasattr(result, "_hidden_params"):
                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                    "id", None
                )
                result._hidden_params["api_base"] = get_api_base(
                    model=model or "",
                    optional_params=getattr(logging_obj, "optional_params", {}),
                )
                result._hidden_params["response_cost"] = (
                    logging_obj._response_cost_calculator(result=result)
                )

                result._hidden_params["additional_headers"] = process_response_headers(
                    result._hidden_params.get("additional_headers") or {}
                )  # GUARANTEE OPENAI HEADERS IN RESPONSE
            if result is not None:
                result._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai
            update_response_metadata(
                result=result,
                logging_obj=logging_obj,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
            )
            return result
        except Exception as e:
            call_type = original_function.__name__
@@ -1116,39 +1117,17 @@ def client(original_function): # noqa: PLR0915
                        chunks, messages=kwargs.get("messages", None)
                    )
                else:
                    update_response_metadata(
                        result=result,
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        start_time=start_time,
                        end_time=end_time,
                    )
                    return result
            elif call_type == CallTypes.arealtime.value:
                return result

            # ADD HIDDEN PARAMS - additional call metadata
            if hasattr(result, "_hidden_params"):
                result._hidden_params["litellm_call_id"] = getattr(
                    logging_obj, "litellm_call_id", None
                )
                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                    "id", None
                )
                result._hidden_params["api_base"] = get_api_base(
                    model=model or "",
                    optional_params=kwargs,
                )
                result._hidden_params["response_cost"] = (
                    logging_obj._response_cost_calculator(result=result)
                )
                result._hidden_params["additional_headers"] = process_response_headers(
                    result._hidden_params.get("additional_headers") or {}
                )  # GUARANTEE OPENAI HEADERS IN RESPONSE
            if (
                isinstance(result, ModelResponse)
                or isinstance(result, EmbeddingResponse)
                or isinstance(result, TranscriptionResponse)
            ):
                setattr(
                    result,
                    "_response_ms",
                    (end_time - start_time).total_seconds() * 1000,
                )  # return response latency in ms like openai

            ### POST-CALL RULES ###
            post_call_processing(
                original_response=result, model=model, optional_params=kwargs
@@ -1190,6 +1169,15 @@ def client(original_function): # noqa: PLR0915
                end_time=end_time,
            )

            update_response_metadata(
                result=result,
                logging_obj=logging_obj,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
            )

            return result
        except Exception as e:
            traceback_exception = traceback.format_exc()
@@ -1293,6 +1281,31 @@ def _is_async_request(
    return False


def update_response_metadata(
    result: Any,
    logging_obj: LiteLLMLoggingObject,
    model: Optional[str],
    kwargs: dict,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
) -> None:
    """
    Updates response metadata, adds the following:
        - response._hidden_params
        - response._hidden_params["litellm_overhead_time_ms"]
        - response.response_time_ms
    """
    if result is None:
        return

    metadata = ResponseMetadata(result)
    metadata.set_hidden_params(logging_obj=logging_obj, model=model, kwargs=kwargs)
    metadata.set_timing_metrics(
        start_time=start_time, end_time=end_time, logging_obj=logging_obj
    )
    metadata.apply()


def _select_tokenizer(
    model: str, custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
):
tests/litellm_utils_tests/test_litellm_overhead.py (new file)
@@ -0,0 +1,116 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "openai/gpt-4o",
        "openai/self_hosted",
        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
    ],
)
async def test_litellm_overhead(model):

    litellm._turn_on_debug()
    start_time = datetime.now()
    if model == "openai/self_hosted":
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
        )
    else:
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
        )
    end_time = datetime.now()
    total_time_ms = (end_time - start_time).total_seconds() * 1000
    print(response)
    print(response._hidden_params)
    litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
    # calculate percent of overhead caused by litellm
    overhead_percent = litellm_overhead_ms * 100 / total_time_ms
    print("##########################\n")
    print("total_time_ms", total_time_ms)
    print("response litellm_overhead_ms", litellm_overhead_ms)
    print("litellm overhead_percent {}%".format(overhead_percent))
    print("##########################\n")
    assert litellm_overhead_ms > 0
    assert litellm_overhead_ms < 1000

    # latency overhead should be less than total request time
    assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000

    # latency overhead should be under 40% of total request time
    assert overhead_percent < 40

    pass


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model",
    [
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "openai/gpt-4o",
        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
        "openai/self_hosted",
    ],
)
async def test_litellm_overhead_stream(model):

    litellm._turn_on_debug()
    start_time = datetime.now()
    if model == "openai/self_hosted":
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
            stream=True,
        )
    else:
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            stream=True,
        )

    async for chunk in response:
        print()

    end_time = datetime.now()
    total_time_ms = (end_time - start_time).total_seconds() * 1000
    print(response)
    print(response._hidden_params)
    litellm_overhead_ms = response._hidden_params["litellm_overhead_time_ms"]
    # calculate percent of overhead caused by litellm
    overhead_percent = litellm_overhead_ms * 100 / total_time_ms
    print("##########################\n")
    print("total_time_ms", total_time_ms)
    print("response litellm_overhead_ms", litellm_overhead_ms)
    print("litellm overhead_percent {}%".format(overhead_percent))
    print("##########################\n")
    assert litellm_overhead_ms > 0
    assert litellm_overhead_ms < 1000

    # latency overhead should be less than total request time
    assert litellm_overhead_ms < (end_time - start_time).total_seconds() * 1000

    # latency overhead should be under 40% of total request time
    assert overhead_percent < 40

    pass