LiteLLM Minor Fixes & Improvements (12/27/2024) - p1 (#7448)

* feat(main.py): support 'litellm.ContextWindowExceededError' as a mock_response value

enables quicker debugging of router, fallback, and proxy behavior on context window errors
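
For example, a minimal sketch of how the mock can be used to exercise context-window handling without calling a real model (the model name and messages below are illustrative):

    import litellm

    # Passing the special string as mock_response makes completion() raise the
    # mapped exception instead of contacting any provider.
    try:
        litellm.completion(
            model="gpt-3.5-turbo",  # illustrative
            messages=[{"role": "user", "content": "hey"}],
            mock_response="litellm.ContextWindowExceededError",
        )
    except litellm.ContextWindowExceededError as e:
        print("caught:", e)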

* feat(exception_mapping_utils.py): extract special litellm errors from the error string when `litellm_proxy/` is the provider

Closes https://github.com/BerriAI/litellm/issues/7259
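
Illustrative only (the proxy `api_base` and key below are placeholders): with this change, a call routed through a LiteLLM proxy using the `litellm_proxy/` provider should surface the original typed error instead of a generic exception:

    import litellm

    try:
        litellm.completion(
            model="litellm_proxy/gpt-3.5-turbo",  # call a LiteLLM proxy as the provider
            api_base="http://localhost:4000",     # placeholder proxy URL
            api_key="sk-1234",                    # placeholder key
            messages=[{"role": "user", "content": "a very long prompt ..."}],
        )
    except litellm.ContextWindowExceededError:
        # the special litellm error is now parsed out of the proxy's error string
        # and re-raised, instead of arriving as a plain BadRequestError
        pass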

* fix(user_api_key_auth.py): mark the 'Received Proxy Server Request' span as SpanKind.SERVER

Closes https://github.com/BerriAI/litellm/issues/7298
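
For context, a minimal OpenTelemetry snippet (standard OTel Python API, nothing LiteLLM-specific) showing what "span kind server" means:

    from opentelemetry import trace
    from opentelemetry.trace import SpanKind

    tracer = trace.get_tracer("example")

    # SpanKind.SERVER marks the span as the entry point for an incoming request,
    # which is how trace backends pick the root span of a request trace.
    with tracer.start_as_current_span("Received Proxy Server Request", kind=SpanKind.SERVER):
        pass  # handle the request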
Krish Dholakia, 2024-12-27 19:04:39 -08:00 (committed by GitHub)
commit 67b39bacf7, parent cca9cfe667
7 changed files with 83 additions and 5 deletions


@@ -337,20 +337,22 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = "litellm.ContextWindowExceededError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         self.response = httpx.Response(status_code=400, request=request)
         super().__init__(
-            message=self.message,
+            message=message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
             response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs
+
+        # set after, to make it clear the raised error is a context window exceeded error
+        self.message = "litellm.ContextWindowExceededError: {}".format(self.message)

     def __str__(self):
         _message = self.message
         if self.num_retries:
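
A quick sketch of the resulting behavior (assuming the patched litellm is importable): because the prefix is now applied after the base-class constructor runs, the final message carries both class prefixes, which is the message shape the new proxy-extraction test below relies on:

    import litellm

    try:
        raise litellm.ContextWindowExceededError(
            message="too many tokens",
            model="gpt-3.5-turbo",
            llm_provider="openai",
        )
    except litellm.ContextWindowExceededError as e:
        # the base class prepends its own prefix first, then this class adds its name on top
        assert str(e).startswith("litellm.ContextWindowExceededError: litellm.BadRequestError:")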


@@ -84,6 +84,7 @@ class OpenTelemetry(CustomLogger):
         from opentelemetry import trace
         from opentelemetry.sdk.resources import Resource
         from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.trace import SpanKind

         if config is None:
             config = OpenTelemetryConfig.from_env()
@@ -99,6 +100,8 @@ class OpenTelemetry(CustomLogger):
         trace.set_tracer_provider(provider)
         self.tracer = trace.get_tracer(LITELLM_TRACER_NAME)

+        self.span_kind = SpanKind
+
         _debug_otel = str(os.getenv("DEBUG_OTEL", "False")).lower()

         if _debug_otel == "true":
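
The point of `self.span_kind = SpanKind` is that code holding the logger instance can reference span kinds without importing OpenTelemetry itself. A rough sketch, assuming `otel_logger` is an already-initialized `OpenTelemetry` instance:

    from opentelemetry.trace import SpanKind

    # otel_logger.span_kind is the SpanKind enum itself, so the two forms are equivalent
    assert otel_logger.span_kind is SpanKind
    span = otel_logger.tracer.start_span("demo", kind=otel_logger.span_kind.SERVER)
    span.end()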

exception_mapping_utils.py:

@@ -1,6 +1,6 @@
 import json
 import traceback
-from typing import Optional
+from typing import Any, Optional

 import httpx
@@ -84,6 +84,41 @@ def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]:
     return _response_headers


+import re
+
+
+def extract_and_raise_litellm_exception(
+    response: Optional[Any],
+    error_str: str,
+    model: str,
+    custom_llm_provider: str,
+):
+    """
+    Covers scenario where litellm sdk calling proxy.
+
+    Enables raising the special errors raised by litellm, eg. ContextWindowExceededError.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/7259
+    """
+    pattern = r"litellm\.\w+Error"
+
+    # Search for the exception in the error string
+    match = re.search(pattern, error_str)
+
+    # Extract the exception if found
+    if match:
+        exception_name = match.group(0)
+        exception_name = exception_name.strip().replace("litellm.", "")
+        raised_exception_obj = getattr(litellm, exception_name, None)
+        if raised_exception_obj:
+            raise raised_exception_obj(
+                message=error_str,
+                llm_provider=custom_llm_provider,
+                model=model,
+                response=response,
+            )
+
+
 def exception_type(  # type: ignore  # noqa: PLR0915
     model,
     original_exception,
@@ -197,6 +232,15 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                 litellm_debug_info=extra_information,
             )

+        if (
+            custom_llm_provider == "litellm_proxy"
+        ):  # handle special case where calling litellm proxy + exception str contains error message
+            extract_and_raise_litellm_exception(
+                response=getattr(original_exception, "response", None),
+                error_str=error_str,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+            )
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "text-completion-openai"
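
A standalone demonstration of the extraction pattern (plain `re`, independent of litellm), using an error string shaped like the one in the new test:

    import re

    error_str = (
        "Error code: 400 - {'error': {'message': \"litellm.ContextWindowExceededError: "
        "litellm.BadRequestError: this is a mock context window exceeded error\", "
        "'type': None, 'param': None, 'code': '400'}}"
    )

    match = re.search(r"litellm\.\w+Error", error_str)
    assert match is not None
    print(match.group(0))                          # litellm.ContextWindowExceededError
    print(match.group(0).replace("litellm.", ""))  # name passed to getattr(litellm, ...)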

main.py:

@@ -550,6 +550,17 @@ def _handle_mock_potential_exceptions(
             ),  # type: ignore
             model=model,
         )
+    elif (
+        isinstance(mock_response, str)
+        and mock_response == "litellm.ContextWindowExceededError"
+    ):
+        raise litellm.ContextWindowExceededError(
+            message="this is a mock context window exceeded error",
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,
+        )
     elif (
         isinstance(mock_response, str)
         and mock_response == "litellm.InternalServerError"
@@ -734,7 +745,7 @@ def mock_completion(
     except Exception as e:
         if isinstance(e, openai.APIError):
             raise e
-        raise Exception("Mock completion response failed")
+        raise Exception("Mock completion response failed - {}".format(e))


 @client
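
Combined with the router, the new mock string makes context-window fallbacks testable offline. A rough, untested sketch (model names are illustrative; it assumes `mock_response` passes through `litellm_params` and that `context_window_fallbacks` is configured on the Router, both existing litellm features):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "litellm.ContextWindowExceededError",
                },
            },
            {
                "model_name": "gpt-4-32k",
                "litellm_params": {"model": "gpt-4-32k", "mock_response": "fallback worked"},
            },
        ],
        context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-4-32k"]}],
    )

    # the first deployment raises the mocked ContextWindowExceededError,
    # so the router should fall back to gpt-4-32k and return the mocked text
    resp = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(resp.choices[0].message.content)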


@@ -14,4 +14,4 @@ model_list:
 router_settings:
   routing_strategy: usage-based-routing-v2
   disable_cooldowns: True

user_api_key_auth.py:

@@ -281,12 +281,14 @@ async def user_api_key_auth(  # noqa: PLR0915
         )

         if open_telemetry_logger is not None:
             parent_otel_span = open_telemetry_logger.tracer.start_span(
                 name="Received Proxy Server Request",
                 start_time=_to_ns(start_time),
                 context=open_telemetry_logger.get_traceparent_from_header(
                     headers=request.headers
                 ),
+                kind=open_telemetry_logger.span_kind.SERVER,
             )

     ### USER-DEFINED AUTH FUNCTION ###
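
Marking this span as SERVER (rather than OpenTelemetry's default INTERNAL) tells trace backends that it is the entry point of the incoming request, so the proxy request shows up as the root of the trace instead of a nested child span.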


@@ -1189,3 +1189,19 @@ def test_exceptions_base_class():
         assert isinstance(e, litellm.RateLimitError)
         assert e.code == "429"
         assert e.type == "throttling_error"
+
+
+def test_context_window_exceeded_error_from_litellm_proxy():
+    from httpx import Response
+
+    from litellm.litellm_core_utils.exception_mapping_utils import (
+        extract_and_raise_litellm_exception,
+    )
+
+    args = {
+        "response": Response(status_code=400, text="Bad Request"),
+        "error_str": "Error code: 400 - {'error': {'message': \"litellm.ContextWindowExceededError: litellm.BadRequestError: this is a mock context window exceeded error\\nmodel=gpt-3.5-turbo. context_window_fallbacks=None. fallbacks=None.\\n\\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks\\nReceived Model Group=gpt-3.5-turbo\\nAvailable Model Group Fallbacks=None\", 'type': None, 'param': None, 'code': '400'}}",
+        "model": "gpt-3.5-turbo",
+        "custom_llm_provider": "litellm_proxy",
+    }
+    with pytest.raises(litellm.ContextWindowExceededError):
+        extract_and_raise_litellm_exception(**args)
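
Note how this test ties the changes in this commit together: the `error_str` is shaped like what a proxy returns when the main.py mock fires behind it ("this is a mock context window exceeded error"), the exception class supplies the "litellm.ContextWindowExceededError: litellm.BadRequestError: ..." prefix ordering, and the new extraction helper maps that proxy error string back to a typed ContextWindowExceededError on the SDK side.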