Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
Litellm dev 01 20 2025 p1 (#7884)
* fix(initial-test-to-return-api-timeout-value-in-openai-timeout-exception): makes it easier for users to debug why a request timed out
* feat(openai.py): return the timeout value and time taken on OpenAI timeout errors, to help debug timeout errors
* fix(utils.py): fix num_retries extraction logic when num_retries = 0
* fix(config_settings.md): document litellm_logging.py support for printing the payload to the console when 'LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD' is true, enabling easier debugging
* test(test_auth_checks.py): remove the common_checks user_api_key_auth enforcement check
* fix(litellm_logging.py): fix linting error
parent a4d3276bed
commit 4c1d4acabc
11 changed files with 65 additions and 46 deletions
@@ -437,6 +437,7 @@ router_settings:
 | LITELLM_SALT_KEY | Salt key for encryption in LiteLLM |
 | LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE | AWS KMS encrypted license for LiteLLM |
 | LITELLM_TOKEN | Access token for LiteLLM integration |
+| LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging |
 | LOGFIRE_TOKEN | Token for Logfire logging service |
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services |
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services |
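A minimal sketch of how the new flag might be exercised from the SDK. The model name and mock_response are placeholders, and INFO-level logging is assumed to be enabled so the verbose_logger output is actually visible:

import os

# assumption: set both env vars before importing litellm so the log level is picked up
os.environ["LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"] = "true"
os.environ["LITELLM_LOG"] = "INFO"

import litellm

# any completion call now also prints the standard logging payload as pretty-printed JSON
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello",  # placeholder response, avoids a real API call
)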
@@ -224,8 +224,9 @@ def exception_type( # type: ignore # noqa: PLR0915
                 or "Timed out generating response" in error_str
             ):
                 exception_mapping_worked = True

                 raise Timeout(
-                    message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",
+                    message=f"APITimeoutError - Request timed out. Error_str: {error_str}",
                     model=model,
                     llm_provider=custom_llm_provider,
                     litellm_debug_info=extra_information,
@@ -3069,8 +3069,7 @@ def get_standard_logging_object_payload(
     original_exception: Optional[Exception] = None,
 ) -> Optional[StandardLoggingPayload]:
     try:
-        if kwargs is None:
-            kwargs = {}
+        kwargs = kwargs or {}

         hidden_params: Optional[dict] = None
         if init_response_obj is None:
@@ -3239,6 +3238,7 @@ def get_standard_logging_object_payload(
             ),
         )

+        emit_standard_logging_payload(payload)
         return payload
     except Exception as e:
         verbose_logger.exception(
@@ -3247,6 +3247,11 @@ def get_standard_logging_object_payload(
         return None


+def emit_standard_logging_payload(payload: StandardLoggingPayload):
+    if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"):
+        verbose_logger.info(json.dumps(payload, indent=4))
+
+
 def get_standard_logging_metadata(
     metadata: Optional[Dict[str, Any]]
 ) -> StandardLoggingMetadata:
@@ -1,4 +1,5 @@
 import hashlib
+import time
 import types
 from typing import (
     Any,
@@ -390,12 +391,14 @@ class OpenAIChatCompletion(BaseLLM):
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
         - call chat.completions.create by default
         """
+        start_time = time.time()
         try:
             raw_response = (
                 await openai_aclient.chat.completions.with_raw_response.create(
                     **data, timeout=timeout
                 )
             )
+            end_time = time.time()

             if hasattr(raw_response, "headers"):
                 headers = dict(raw_response.headers)
@@ -403,6 +406,11 @@ class OpenAIChatCompletion(BaseLLM):
                 headers = {}
             response = raw_response.parse()
             return headers, response
+        except openai.APITimeoutError as e:
+            end_time = time.time()
+            time_delta = round(end_time - start_time, 2)
+            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
+            raise e
         except Exception as e:
             raise e

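A rough sketch of how the enriched timeout message surfaces to a caller. The tiny timeout value is only there to force the error, and the exact wording of the message depends on the provider path:

import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "write a long story"}],
        timeout=0.01,   # deliberately small, to trigger an APITimeoutError
        num_retries=0,  # fail immediately instead of retrying
    )
except litellm.Timeout as e:
    # with this change the message also reports the configured timeout value
    # and the measured time taken, which makes slow-endpoint debugging easier
    print(e)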
@@ -521,6 +529,7 @@ class OpenAIChatCompletion(BaseLLM):
         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message

             if provider_config is not None:
                 data = provider_config.transform_request(
                     model=model,
@@ -725,6 +734,7 @@ class OpenAIChatCompletion(BaseLLM):
         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message

             try:
                 openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                     is_async=True,
@@ -792,9 +802,10 @@ class OpenAIChatCompletion(BaseLLM):
             error_headers = getattr(e, "headers", None)
             if error_headers is None and exception_response:
                 error_headers = getattr(exception_response, "headers", None)
+            message = getattr(e, "message", str(e))

             raise OpenAIError(
-                status_code=status_code, message=str(e), headers=error_headers
+                status_code=status_code, message=message, headers=error_headers
             )

     def streaming(
@@ -9,5 +9,5 @@ model_list:
       model: openai/random_sleep
       api_key: sk-1234
       api_base: http://0.0.0.0:8090
-    model_info:
-      health_check_timeout: 1
+      timeout: 2
+      num_retries: 0
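For reference, the per-deployment settings this test config sets in YAML can also be expressed with the Python Router API. A sketch; the alias, key, and endpoint are stand-ins taken from the test config above:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "random-sleep",  # stand-in alias for the deployment
            "litellm_params": {
                "model": "openai/random_sleep",
                "api_key": "sk-1234",
                "api_base": "http://0.0.0.0:8090",
                "timeout": 2,      # request-level timeout in seconds
                "num_retries": 0,  # surface the timeout instead of retrying
            },
        }
    ]
)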
@@ -3073,6 +3073,7 @@ class Router:
             deployment_num_retries, int
         ):
             num_retries = deployment_num_retries

         """
         Retry Logic
         """
@@ -3119,6 +3120,7 @@
             )

             await asyncio.sleep(retry_after)

         for current_attempt in range(num_retries):
             try:
                 # if the function call is successful, no exception will be raised and we'll break out of the loop
@@ -622,6 +622,28 @@ async def _client_async_logging_helper(
         )


+def _get_wrapper_num_retries(
+    kwargs: Dict[str, Any], exception: Exception
+) -> Tuple[Optional[int], Dict[str, Any]]:
+    """
+    Get the number of retries from the kwargs and the retry policy.
+    Used for the wrapper functions.
+    """
+    num_retries = kwargs.get("num_retries", None)
+    if num_retries is None:
+        num_retries = litellm.num_retries
+    if kwargs.get("retry_policy", None):
+        retry_policy_num_retries = get_num_retries_from_retry_policy(
+            exception=exception,
+            retry_policy=kwargs.get("retry_policy"),
+        )
+        kwargs["retry_policy"] = reset_retry_policy()
+        if retry_policy_num_retries is not None:
+            num_retries = retry_policy_num_retries
+
+    return num_retries, kwargs
+
+
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()

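The key behavioural fix: the old nested helper (removed in the next hunk) used an `or` chain, so an explicit `num_retries=0` was treated as falsy and fell back to `litellm.num_retries`, while the new helper only falls back when the value is `None`. A standalone illustration of the difference, using hypothetical helpers rather than the LiteLLM code itself:

from typing import Optional


def old_style(num_retries: Optional[int], default: Optional[int]) -> Optional[int]:
    # 0 is falsy, so an explicit 0 is silently replaced by the default
    return num_retries or default or None


def new_style(num_retries: Optional[int], default: Optional[int]) -> Optional[int]:
    # only fall back when the caller did not pass a value at all
    return default if num_retries is None else num_retries


assert old_style(0, 3) == 3     # bug: retries are forced back on
assert new_style(0, 3) == 0     # fix: zero retries is respected
assert new_style(None, 3) == 3  # unset still falls back to the default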
@@ -736,19 +758,6 @@ def client(original_function): # noqa: PLR0915
         except Exception as e:
             raise e

-    def _get_num_retries(
-        kwargs: Dict[str, Any], exception: Exception
-    ) -> Tuple[Optional[int], Dict[str, Any]]:
-        num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
-        if kwargs.get("retry_policy", None):
-            num_retries = get_num_retries_from_retry_policy(
-                exception=exception,
-                retry_policy=kwargs.get("retry_policy"),
-            )
-            kwargs["retry_policy"] = reset_retry_policy()
-
-        return num_retries, kwargs
-
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -1200,7 +1209,7 @@ def client(original_function): # noqa: PLR0915
                 raise e

             call_type = original_function.__name__
-            num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
+            num_retries, kwargs = _get_wrapper_num_retries(kwargs=kwargs, exception=e)
             if call_type == CallTypes.acompletion.value:
                 context_window_fallback_dict = kwargs.get(
                     "context_window_fallback_dict", {}
@@ -201,31 +201,6 @@ async def test_can_team_call_model(model, expect_to_work):
         assert not model_in_access_group(**args)


-def test_common_checks_import():
-    """
-    Enforce that common_checks can only be imported by the 'user_api_key_auth()' function.
-    """
-    try:
-        from litellm.proxy.auth.user_api_key_auth import common_checks
-        from litellm.proxy._types import CommonProxyErrors
-
-        common_checks(
-            request_body={},
-            team_object=None,
-            user_object=None,
-            end_user_object=None,
-            global_proxy_spend=None,
-            general_settings={},
-            route="",
-            llm_router=None,
-        )
-        pytest.fail(
-            "common_checks can only be imported by the 'user_api_key_auth()' function."
-        )
-    except Exception as e:
-        assert CommonProxyErrors.not_premium_user.value in str(e)
-
-
 @pytest.mark.asyncio
 async def test_is_valid_fallback_model():
     from litellm.proxy.auth.auth_checks import is_valid_fallback_model
@@ -74,6 +74,7 @@ async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
         created=int(datetime.now().timestamp()),
     )
     litellm.set_verbose = True

     mock_request = respx_mock.post(url__regex=r".*/chat/completions.*").mock(
         return_value=httpx.Response(200, json=obj.model_dump(mode="json"))
     )
@@ -20,7 +20,7 @@ import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from unittest.mock import AsyncMock, MagicMock, patch
+from respx import MockRouter
 import httpx
 from dotenv import load_dotenv
 from pydantic import BaseModel
@@ -1480,3 +1480,17 @@ def test_get_potential_model_names():
         model="bedrock/ap-northeast-1/anthropic.claude-instant-v1",
         custom_llm_provider="bedrock",
     )
+
+
+@pytest.mark.parametrize("num_retries", [0, 1, 5])
+def test_get_num_retries(num_retries):
+    from litellm.utils import _get_wrapper_num_retries
+
+    assert _get_wrapper_num_retries(
+        kwargs={"num_retries": num_retries}, exception=Exception("test")
+    ) == (
+        num_retries,
+        {
+            "num_retries": num_retries,
+        },
+    )