Litellm dev 01 20 2025 p1 (#7884)

* fix(initial-test-to-return-api-timeout-value-in-openai-timeout-exception): Makes it easier for users to debug why a request timed out

* feat(openai.py): return timeout value + time taken on openai timeout errors

helps debug timeout errors

* fix(utils.py): fix num retries extraction logic when num_retries = 0

* fix(config_settings.md, litellm_logging.py): support printing the standard logging payload to the console when 'LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD' is true

Enables easier debugging

* test(test_auth_checks.py): remove the common_checks 'user_api_key_auth' enforcement check

* fix(litellm_logging.py): fix linting error
Krish Dholakia 2025-01-20 21:45:48 -08:00 committed by GitHub
parent 806df5d31c
commit 4b23420a20
11 changed files with 65 additions and 46 deletions

View file

@@ -437,6 +437,7 @@ router_settings:
 | LITELLM_SALT_KEY | Salt key for encryption in LiteLLM
 | LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE | AWS KMS encrypted license for LiteLLM
 | LITELLM_TOKEN | Access token for LiteLLM integration
+| LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging
 | LOGFIRE_TOKEN | Token for Logfire logging service
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
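
A minimal sketch of how this flag might be enabled from Python before issuing a request. This is an assumption-laden example, not part of the commit: the model name and `mock_response` usage are illustrative, and the payload is printed through LiteLLM's verbose logger at INFO level, so your logging configuration must surface INFO lines.

```python
import os

# Assumption: any non-empty value enables the flag, since the code gates on os.getenv(...)
os.environ["LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"] = "True"
# Depending on your logging setup, you may also need to raise LiteLLM's log level,
# e.g. os.environ["LITELLM_LOG"] = "INFO", so the verbose_logger.info() output is visible.

import litellm

# mock_response avoids a real provider call; this is purely to trigger the logging path
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    mock_response="hi there",
)
# With the flag set, the standard logging payload should be printed to the console as JSON.
```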

View file

@ -224,8 +224,9 @@ def exception_type( # type: ignore # noqa: PLR0915
or "Timed out generating response" in error_str or "Timed out generating response" in error_str
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}", message=f"APITimeoutError - Request timed out. Error_str: {error_str}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,

View file

@@ -3069,8 +3069,7 @@ def get_standard_logging_object_payload(
     original_exception: Optional[Exception] = None,
 ) -> Optional[StandardLoggingPayload]:
     try:
-        if kwargs is None:
-            kwargs = {}
+        kwargs = kwargs or {}

         hidden_params: Optional[dict] = None
         if init_response_obj is None:
@@ -3239,6 +3238,7 @@ def get_standard_logging_object_payload(
             ),
         )

+        emit_standard_logging_payload(payload)
         return payload
     except Exception as e:
         verbose_logger.exception(
@@ -3247,6 +3247,11 @@ def get_standard_logging_object_payload(
         return None


+def emit_standard_logging_payload(payload: StandardLoggingPayload):
+    if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"):
+        verbose_logger.info(json.dumps(payload, indent=4))
+
+
 def get_standard_logging_metadata(
     metadata: Optional[Dict[str, Any]]
 ) -> StandardLoggingMetadata:

View file

@@ -1,4 +1,5 @@
 import hashlib
+import time
 import types
 from typing import (
     Any,
@@ -390,12 +391,14 @@ class OpenAIChatCompletion(BaseLLM):
        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
        - call chat.completions.create by default
        """
+        start_time = time.time()
        try:
            raw_response = (
                await openai_aclient.chat.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )
+            end_time = time.time()

            if hasattr(raw_response, "headers"):
                headers = dict(raw_response.headers)
@@ -403,6 +406,11 @@ class OpenAIChatCompletion(BaseLLM):
                headers = {}
            response = raw_response.parse()
            return headers, response
+        except openai.APITimeoutError as e:
+            end_time = time.time()
+            time_delta = round(end_time - start_time, 2)
+            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
+            raise e
        except Exception as e:
            raise e
@@ -521,6 +529,7 @@ class OpenAIChatCompletion(BaseLLM):
            for _ in range(
                2
            ):  # if call fails due to alternating messages, retry with reformatted message
+
                if provider_config is not None:
                    data = provider_config.transform_request(
                        model=model,
@@ -725,6 +734,7 @@ class OpenAIChatCompletion(BaseLLM):
        for _ in range(
            2
        ):  # if call fails due to alternating messages, retry with reformatted message
+
            try:
                openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                    is_async=True,
@@ -792,9 +802,10 @@ class OpenAIChatCompletion(BaseLLM):
                error_headers = getattr(e, "headers", None)
                if error_headers is None and exception_response:
                    error_headers = getattr(exception_response, "headers", None)
+                message = getattr(e, "message", str(e))

                raise OpenAIError(
-                    status_code=status_code, message=str(e), headers=error_headers
+                    status_code=status_code, message=message, headers=error_headers
                )

    def streaming(
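
With these changes, an OpenAI timeout surfaces as a `litellm.Timeout` whose message should also carry the configured timeout and the elapsed time. A hedged sketch of how a caller might see that information; the model name and the deliberately tiny timeout below are illustrative, not from the commit:

```python
import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model
        messages=[{"role": "user", "content": "write a long story"}],
        timeout=0.01,   # deliberately tiny so the request times out
        num_retries=0,  # fail fast instead of retrying
    )
except litellm.Timeout as e:
    # The message is expected to now include something like
    # "timeout value=0.01, time taken=... seconds"
    print(e)
```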

View file

@@ -9,5 +9,5 @@ model_list:
       model: openai/random_sleep
       api_key: sk-1234
       api_base: http://0.0.0.0:8090
-    model_info:
-      health_check_timeout: 1
+      timeout: 2
+      num_retries: 0
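
This test config gives the deployment a 2-second timeout and zero retries. Under the proxy these entries feed the Router's model_list; a rough Python equivalent is sketched below, where the `model_name` alias is illustrative (it is not visible in the hunk) and the endpoint values simply mirror the config above:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "random_sleep",  # alias clients call; illustrative
            "litellm_params": {
                "model": "openai/random_sleep",
                "api_key": "sk-1234",
                "api_base": "http://0.0.0.0:8090",
                "timeout": 2,      # per-deployment request timeout (seconds)
                "num_retries": 0,  # respected as 0 after this change, not replaced by a default
            },
        }
    ]
)

# Example call (inside an async context):
# response = await router.acompletion(
#     model="random_sleep", messages=[{"role": "user", "content": "hi"}]
# )
```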

View file

@@ -3073,6 +3073,7 @@ class Router:
                 deployment_num_retries, int
             ):
                 num_retries = deployment_num_retries
+
             """
             Retry Logic
             """
@@ -3119,6 +3120,7 @@ class Router:
                 )
                 await asyncio.sleep(retry_after)
+
             for current_attempt in range(num_retries):
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop

View file

@@ -622,6 +622,28 @@ async def _client_async_logging_helper(
     )


+def _get_wrapper_num_retries(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Tuple[Optional[int], Dict[str, Any]]:
+    """
+    Get the number of retries from the kwargs and the retry policy.
+
+    Used for the wrapper functions.
+    """
+    num_retries = kwargs.get("num_retries", None)
+    if num_retries is None:
+        num_retries = litellm.num_retries
+    if kwargs.get("retry_policy", None):
+        retry_policy_num_retries = get_num_retries_from_retry_policy(
+            exception=exception,
+            retry_policy=kwargs.get("retry_policy"),
+        )
+        kwargs["retry_policy"] = reset_retry_policy()
+        if retry_policy_num_retries is not None:
+            num_retries = retry_policy_num_retries
+
+    return num_retries, kwargs
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
@@ -736,19 +758,6 @@ def client(original_function):  # noqa: PLR0915
             except Exception as e:
                 raise e

-        def _get_num_retries(
-            kwargs: Dict[str, Any], exception: Exception
-        ) -> Tuple[Optional[int], Dict[str, Any]]:
-            num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
-            if kwargs.get("retry_policy", None):
-                num_retries = get_num_retries_from_retry_policy(
-                    exception=exception,
-                    retry_policy=kwargs.get("retry_policy"),
-                )
-                kwargs["retry_policy"] = reset_retry_policy()
-            return num_retries, kwargs
-
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -1200,7 +1209,7 @@ def client(original_function):  # noqa: PLR0915
                 raise e
             call_type = original_function.__name__
-            num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
+            num_retries, kwargs = _get_wrapper_num_retries(kwargs=kwargs, exception=e)
             if call_type == CallTypes.acompletion.value:
                 context_window_fallback_dict = kwargs.get(
                     "context_window_fallback_dict", {}

View file

@@ -201,31 +201,6 @@ async def test_can_team_call_model(model, expect_to_work):
         assert not model_in_access_group(**args)


-def test_common_checks_import():
-    """
-    Enforce that common_checks can only be imported by the 'user_api_key_auth()' function.
-    """
-    try:
-        from litellm.proxy.auth.user_api_key_auth import common_checks
-        from litellm.proxy._types import CommonProxyErrors
-
-        common_checks(
-            request_body={},
-            team_object=None,
-            user_object=None,
-            end_user_object=None,
-            global_proxy_spend=None,
-            general_settings={},
-            route="",
-            llm_router=None,
-        )
-        pytest.fail(
-            "common_checks can only be imported by the 'user_api_key_auth()' function."
-        )
-    except Exception as e:
-        assert CommonProxyErrors.not_premium_user.value in str(e)
-
-
 @pytest.mark.asyncio
 async def test_is_valid_fallback_model():
     from litellm.proxy.auth.auth_checks import is_valid_fallback_model

View file

@@ -74,6 +74,7 @@ async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
         created=int(datetime.now().timestamp()),
     )
     litellm.set_verbose = True
+
     mock_request = respx_mock.post(url__regex=r".*/chat/completions.*").mock(
         return_value=httpx.Response(200, json=obj.model_dump(mode="json"))
     )

View file

@@ -20,7 +20,7 @@ import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from unittest.mock import AsyncMock, MagicMock, patch
+from respx import MockRouter

 import httpx
 from dotenv import load_dotenv
 from pydantic import BaseModel

View file

@@ -1480,3 +1480,17 @@ def test_get_potential_model_names():
         model="bedrock/ap-northeast-1/anthropic.claude-instant-v1",
         custom_llm_provider="bedrock",
     )
+
+
+@pytest.mark.parametrize("num_retries", [0, 1, 5])
+def test_get_num_retries(num_retries):
+    from litellm.utils import _get_wrapper_num_retries
+
+    assert _get_wrapper_num_retries(
+        kwargs={"num_retries": num_retries}, exception=Exception("test")
+    ) == (
+        num_retries,
+        {
+            "num_retries": num_retries,
+        },
+    )
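
The parametrized test above checks that an explicit `num_retries` (including 0) comes back unchanged. The underlying bug was that the old helper used `kwargs.get("num_retries", None) or litellm.num_retries or None`, and `or` treats 0 as falsy, so an explicit "don't retry" was silently replaced by the default. A standalone sketch of the difference, with an illustrative default standing in for `litellm.num_retries`:

```python
from typing import Optional

GLOBAL_DEFAULT: Optional[int] = 3  # stand-in for litellm.num_retries


def old_extraction(kwargs: dict) -> Optional[int]:
    # `or` treats 0 as falsy, so an explicit num_retries=0 is replaced by the default
    return kwargs.get("num_retries", None) or GLOBAL_DEFAULT or None


def new_extraction(kwargs: dict) -> Optional[int]:
    num_retries = kwargs.get("num_retries", None)
    if num_retries is None:  # only fall back when the caller didn't set anything
        num_retries = GLOBAL_DEFAULT
    return num_retries


print(old_extraction({"num_retries": 0}))  # 3 -> the bug: explicit 0 is ignored
print(new_extraction({"num_retries": 0}))  # 0 -> explicit "don't retry" is respected
```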