Litellm dev 01 20 2025 p1 (#7884)

* fix(initial-test-to-return-api-timeout-value-in-openai-timeout-exception): makes it easier for users to debug why a request timed out

* feat(openai.py): return timeout value + time taken on OpenAI timeout errors

Helps debug timeout errors

* fix(utils.py): fix num retries extraction logic when num_retries = 0

* fix(config_settings.md, litellm_logging.py): support printing the standard logging payload to the console when 'LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD' is true

Enables easier debugging

* test(test_auth_checks.py): remove the common_checks user_api_key_auth enforcement check

* fix(litellm_logging.py): fix linting error
Krish Dholakia 2025-01-20 21:45:48 -08:00 committed by GitHub
parent a4d3276bed
commit 4c1d4acabc
11 changed files with 65 additions and 46 deletions

View file

@@ -437,6 +437,7 @@ router_settings:
 | LITELLM_SALT_KEY | Salt key for encryption in LiteLLM
 | LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE | AWS KMS encrypted license for LiteLLM
 | LITELLM_TOKEN | Access token for LiteLLM integration
+| LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging
 | LOGFIRE_TOKEN | Token for Logfire logging service
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
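
A quick sketch of how the new flag can be exercised from Python (the model name and message are placeholders, and a valid provider key is assumed):

```python
import os

# Note: the diff gates on os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD")
# being truthy, so any non-empty string (even "false") enables the print.
# Leave the variable unset to disable it.
os.environ["LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"] = "true"

import litellm

# With the flag set, each call's standard logging payload is
# pretty-printed to the console by emit_standard_logging_payload().
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
)
```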

View file

@@ -224,8 +224,9 @@ def exception_type( # type: ignore # noqa: PLR0915
                 or "Timed out generating response" in error_str
             ):
                 exception_mapping_worked = True
                 raise Timeout(
-                    message=f"APITimeoutError - Request timed out. Error_str: {error_str}",
+                    message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",
                     model=model,
                     llm_provider=custom_llm_provider,
                     litellm_debug_info=extra_information,

View file

@@ -3069,8 +3069,7 @@ def get_standard_logging_object_payload(
     original_exception: Optional[Exception] = None,
 ) -> Optional[StandardLoggingPayload]:
     try:
-        if kwargs is None:
-            kwargs = {}
+        kwargs = kwargs or {}

         hidden_params: Optional[dict] = None
         if init_response_obj is None:
@@ -3239,6 +3238,7 @@ def get_standard_logging_object_payload(
             ),
         )
+        emit_standard_logging_payload(payload)
         return payload
     except Exception as e:
         verbose_logger.exception(
@@ -3247,6 +3247,11 @@ def get_standard_logging_object_payload(
         return None


+def emit_standard_logging_payload(payload: StandardLoggingPayload):
+    if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"):
+        verbose_logger.info(json.dumps(payload, indent=4))
+
+
 def get_standard_logging_metadata(
     metadata: Optional[Dict[str, Any]]
 ) -> StandardLoggingMetadata:
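
For clarity, here is the gating behavior in isolation — a self-contained approximation of `emit_standard_logging_payload` (the payload dict is illustrative, not a full StandardLoggingPayload):

```python
import json
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("litellm-demo")


def emit_payload(payload: dict) -> None:
    # Same gate as the diff: the raw env string decides, so any
    # non-empty value enables printing.
    if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"):
        logger.info(json.dumps(payload, indent=4))


os.environ["LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"] = "1"
emit_payload({"model": "gpt-4o", "response_cost": 0.0012})  # printed

del os.environ["LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"]
emit_payload({"model": "gpt-4o", "response_cost": 0.0012})  # suppressed
```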

View file

@@ -1,4 +1,5 @@
 import hashlib
+import time
 import types
 from typing import (
     Any,
@@ -390,12 +391,14 @@ class OpenAIChatCompletion(BaseLLM):
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
         - call chat.completions.create by default
         """
+        start_time = time.time()
         try:
             raw_response = (
                 await openai_aclient.chat.completions.with_raw_response.create(
                     **data, timeout=timeout
                 )
             )
+            end_time = time.time()

             if hasattr(raw_response, "headers"):
@@ -403,6 +406,11 @@ class OpenAIChatCompletion(BaseLLM):
                 headers = {}
             response = raw_response.parse()
             return headers, response
+        except openai.APITimeoutError as e:
+            end_time = time.time()
+            time_delta = round(end_time - start_time, 2)
+            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
+            raise e
         except Exception as e:
             raise e
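
The pattern above, reduced to a standalone sketch (`asyncio.wait_for` stands in for the OpenAI SDK's internal timeout here; this is not LiteLLM's actual code path):

```python
import asyncio
import time


async def slow_upstream_call() -> str:
    await asyncio.sleep(5)  # stand-in for a slow provider response
    return "ok"


async def call_with_annotated_timeout(timeout: float) -> str:
    start_time = time.time()
    try:
        return await asyncio.wait_for(slow_upstream_call(), timeout=timeout)
    except asyncio.TimeoutError as e:
        # Mirror the diff: report the configured timeout alongside the
        # measured wait, so users can see which limit was actually hit.
        time_delta = round(time.time() - start_time, 2)
        raise TimeoutError(
            f"Request timed out - timeout value={timeout}, "
            f"time taken={time_delta} seconds"
        ) from e


# asyncio.run(call_with_annotated_timeout(timeout=2.0))
# -> TimeoutError: Request timed out - timeout value=2.0, time taken=2.0 seconds
```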
@@ -521,6 +529,7 @@ class OpenAIChatCompletion(BaseLLM):
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
+
                 if provider_config is not None:
                     data = provider_config.transform_request(
                         model=model,
@@ -725,6 +734,7 @@ class OpenAIChatCompletion(BaseLLM):
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
+
                 try:
                     openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                         is_async=True,
@@ -792,9 +802,10 @@ class OpenAIChatCompletion(BaseLLM):
                 error_headers = getattr(e, "headers", None)
                 if error_headers is None and exception_response:
                     error_headers = getattr(exception_response, "headers", None)
+                message = getattr(e, "message", str(e))

                 raise OpenAIError(
-                    status_code=status_code, message=str(e), headers=error_headers
+                    status_code=status_code, message=message, headers=error_headers
                 )

     def streaming(

View file

@@ -9,5 +9,5 @@ model_list:
       model: openai/random_sleep
       api_key: sk-1234
       api_base: http://0.0.0.0:8090
-    model_info:
-      health_check_timeout: 1
+      timeout: 2
+      num_retries: 0
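
Roughly what this test config exercises, expressed as a direct SDK call (the random_sleep model and api_base are the test's fake slow endpoint, not real OpenAI):

```python
import litellm

try:
    litellm.completion(
        model="openai/random_sleep",
        api_key="sk-1234",
        api_base="http://0.0.0.0:8090",
        messages=[{"role": "user", "content": "hello"}],
        timeout=2,
        num_retries=0,  # surface the timeout immediately, no retries
    )
except litellm.Timeout as exc:
    # With this PR, the message now includes the configured timeout
    # value and the measured time taken.
    print(exc)
```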

View file

@@ -3073,6 +3073,7 @@ class Router:
                     deployment_num_retries, int
                 ):
                     num_retries = deployment_num_retries
+
                 """
                 Retry Logic
                 """
@@ -3119,6 +3120,7 @@ class Router:
                     )
                     await asyncio.sleep(retry_after)
+
                 for current_attempt in range(num_retries):
                     try:
                         # if the function call is successful, no exception will be raised and we'll break out of the loop

View file

@@ -622,6 +622,28 @@ async def _client_async_logging_helper(
     )


+def _get_wrapper_num_retries(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Tuple[Optional[int], Dict[str, Any]]:
+    """
+    Get the number of retries from the kwargs and the retry policy.
+    Used for the wrapper functions.
+    """
+    num_retries = kwargs.get("num_retries", None)
+    if num_retries is None:
+        num_retries = litellm.num_retries
+    if kwargs.get("retry_policy", None):
+        retry_policy_num_retries = get_num_retries_from_retry_policy(
+            exception=exception,
+            retry_policy=kwargs.get("retry_policy"),
+        )
+        kwargs["retry_policy"] = reset_retry_policy()
+        if retry_policy_num_retries is not None:
+            num_retries = retry_policy_num_retries
+
+    return num_retries, kwargs
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
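
The `is None` check above is the substance of the num_retries=0 fix: the old helper (removed below) chained `or`, which treats 0 as falsy and silently replaced an explicit num_retries=0 with the global default. A minimal demonstration of the difference:

```python
global_default = 3  # stands in for litellm.num_retries
explicit = 0        # user explicitly asked for zero retries

# Old pattern: `or` drops the 0 and falls through to the default.
old_result = explicit or global_default or None
assert old_result == 3  # user's setting lost

# New pattern: only fall back when the value is genuinely absent.
new_result = explicit if explicit is not None else global_default
assert new_result == 0  # user's setting preserved
```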
@@ -736,19 +758,6 @@ def client(original_function):  # noqa: PLR0915
             except Exception as e:
                 raise e

-    def _get_num_retries(
-        kwargs: Dict[str, Any], exception: Exception
-    ) -> Tuple[Optional[int], Dict[str, Any]]:
-        num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
-        if kwargs.get("retry_policy", None):
-            num_retries = get_num_retries_from_retry_policy(
-                exception=exception,
-                retry_policy=kwargs.get("retry_policy"),
-            )
-            kwargs["retry_policy"] = reset_retry_policy()
-
-        return num_retries, kwargs
-
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -1200,7 +1209,7 @@ def client(original_function):  # noqa: PLR0915
                 raise e

             call_type = original_function.__name__
-            num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
+            num_retries, kwargs = _get_wrapper_num_retries(kwargs=kwargs, exception=e)

             if call_type == CallTypes.acompletion.value:
                 context_window_fallback_dict = kwargs.get(
                     "context_window_fallback_dict", {}

View file

@@ -201,31 +201,6 @@ async def test_can_team_call_model(model, expect_to_work):
         assert not model_in_access_group(**args)


-def test_common_checks_import():
-    """
-    Enforce that common_checks can only be imported by the 'user_api_key_auth()' function.
-    """
-    try:
-        from litellm.proxy.auth.user_api_key_auth import common_checks
-        from litellm.proxy._types import CommonProxyErrors
-
-        common_checks(
-            request_body={},
-            team_object=None,
-            user_object=None,
-            end_user_object=None,
-            global_proxy_spend=None,
-            general_settings={},
-            route="",
-            llm_router=None,
-        )
-        pytest.fail(
-            "common_checks can only be imported by the 'user_api_key_auth()' function."
-        )
-    except Exception as e:
-        assert CommonProxyErrors.not_premium_user.value in str(e)
-
-
 @pytest.mark.asyncio
 async def test_is_valid_fallback_model():
     from litellm.proxy.auth.auth_checks import is_valid_fallback_model

View file

@@ -74,6 +74,7 @@ async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
         created=int(datetime.now().timestamp()),
     )
+    litellm.set_verbose = True

     mock_request = respx_mock.post(url__regex=r".*/chat/completions.*").mock(
         return_value=httpx.Response(200, json=obj.model_dump(mode="json"))
     )

View file

@@ -20,7 +20,7 @@ import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from unittest.mock import AsyncMock, MagicMock, patch
+from respx import MockRouter

 import httpx
 from dotenv import load_dotenv
 from pydantic import BaseModel

View file

@@ -1480,3 +1480,17 @@ def test_get_potential_model_names():
         model="bedrock/ap-northeast-1/anthropic.claude-instant-v1",
         custom_llm_provider="bedrock",
     )
+
+
+@pytest.mark.parametrize("num_retries", [0, 1, 5])
+def test_get_num_retries(num_retries):
+    from litellm.utils import _get_wrapper_num_retries
+
+    assert _get_wrapper_num_retries(
+        kwargs={"num_retries": num_retries}, exception=Exception("test")
+    ) == (
+        num_retries,
+        {
+            "num_retries": num_retries,
+        },
+    )