zishaansunderji 2025-04-24 00:56:01 -07:00 committed by GitHub
commit 42f3c922f4
7 changed files with 390 additions and 6 deletions

View file

@ -1046,6 +1046,7 @@ from .exceptions import (
    JSONSchemaValidationError,
    LITELLM_EXCEPTION_TYPES,
    MockException,
    MidStreamFallbackError,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server

View file

@ -550,6 +550,67 @@ class InternalServerError(openai.InternalServerError): # type: ignore
        return _message


class MidStreamFallbackError(ServiceUnavailableError):  # type: ignore
    def __init__(
        self,
        message: str,
        model: str,
        llm_provider: str,
        original_exception: Optional[Exception] = None,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
        generated_content: str = "",
        is_pre_first_chunk: bool = False,
    ):
        self.status_code = 503  # Service Unavailable
        self.message = f"litellm.MidStreamFallbackError: {message}"
        self.model = model
        self.llm_provider = llm_provider
        self.original_exception = original_exception
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        self.generated_content = generated_content
        self.is_pre_first_chunk = is_pre_first_chunk
        # Create a response if one wasn't provided
        if response is None:
            self.response = httpx.Response(
                status_code=self.status_code,
                request=httpx.Request(
                    method="POST",
                    url=f"https://{llm_provider}.com/v1/",
                ),
            )
        else:
            self.response = response
        # Call the parent constructor
        super().__init__(
            message=self.message,
            llm_provider=llm_provider,
            model=model,
            response=self.response,
            litellm_debug_info=self.litellm_debug_info,
            max_retries=self.max_retries,
            num_retries=self.num_retries,
        )

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        if self.original_exception:
            _message += f" Original exception: {type(self.original_exception).__name__}: {str(self.original_exception)}"
        return _message

    def __repr__(self):
        return self.__str__()


# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(openai.APIError):  # type: ignore
    def __init__(
@ -735,6 +796,7 @@ LITELLM_EXCEPTION_TYPES = [
    OpenAIError,
    InternalServerError,
    JSONSchemaValidationError,
    MidStreamFallbackError,
]
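
The new exception carries enough state for a caller to decide how to recover. A minimal consumer sketch, assuming only the attributes defined above; the helper name and the print statements are illustrative, not part of this commit:

import litellm

def handle_stream_failure(err: Exception) -> None:
    # MidStreamFallbackError exposes the partial text generated before the failure
    # and whether any chunk had already been sent to the client.
    if isinstance(err, litellm.MidStreamFallbackError):
        if err.is_pre_first_chunk:
            # Nothing was streamed yet, so a fallback can simply restart the request.
            print("failed before the first chunk; retry cleanly on a fallback deployment")
        else:
            # Content already reached the client; a fallback should continue from
            # err.generated_content instead of restarting from scratch.
            print(f"failed after {len(err.generated_content)} chars; resume on a fallback")
    else:
        raise err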

View file

@ -15,6 +15,7 @@ from ..exceptions import (
    ContentPolicyViolationError,
    ContextWindowExceededError,
    InternalServerError,
    MidStreamFallbackError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
@ -486,13 +487,32 @@ def exception_type( # type: ignore # noqa: PLR0915
                model=model,
                llm_provider="anthropic",
            )
        elif "overloaded_error" in error_str:
        elif ("overloaded" in error_str.lower() or
              (hasattr(original_exception, "status_code") and
               original_exception.status_code == 529)):
            exception_mapping_worked = True
            raise InternalServerError(
                message="AnthropicError - {}".format(error_str),
                model=model,
                llm_provider="anthropic",
            )
            streaming_obj = getattr(original_exception, "streaming_obj", None)
            generated_content = ""
            is_pre_first_chunk = True
            if streaming_obj:
                is_pre_first_chunk = getattr(streaming_obj, "sent_first_chunk", False) is False
                generated_content = getattr(streaming_obj, "response_uptil_now", "")
            raise MidStreamFallbackError(
                message="AnthropicException - {}".format(error_str),
                model=model,
                llm_provider="anthropic",
                original_exception=original_exception,
                generated_content=generated_content,
                is_pre_first_chunk=is_pre_first_chunk,
            )
        else:
            raise InternalServerError(
                message="AnthropicException - {}".format(error_str),
                model=model,
                llm_provider="anthropic",
            )
        if "Invalid API Key" in error_str:
            exception_mapping_worked = True
            raise AuthenticationError(

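The condition above treats either an "overloaded" message or Anthropic's HTTP 529 status as a trigger for the mid-stream fallback path. A standalone sketch of that detection rule; the helper name is illustrative and not part of this commit:

from typing import Optional

def is_anthropic_overload(error_str: str, status_code: Optional[int] = None) -> bool:
    # Mirrors the check above: match on the error text or on Anthropic's 529
    # "overloaded" status code.
    return "overloaded" in error_str.lower() or status_code == 529

Both is_anthropic_overload("Overloaded", None) and is_anthropic_overload("", 529) return True, which is why either signal alone is enough to raise MidStreamFallbackError.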
View file

@ -779,6 +779,22 @@ class CustomStreamWrapper:
                    is_chunk_non_empty
                ):  # cannot set content of an OpenAI Object to be an empty string
                    self.safety_checker()
                    # Check if this is a response from a mid-stream fallback
                    is_mid_stream_fallback = (
                        hasattr(model_response, "_hidden_params")
                        and isinstance(model_response._hidden_params, dict)
                        and (
                            model_response._hidden_params.get("metadata", {}).get("mid_stream_fallback", False)
                            or model_response._hidden_params.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False)
                        )
                    )
                    if is_mid_stream_fallback and self.sent_first_chunk is False:
                        # Skip sending the role in the first chunk since it was sent by the original call
                        if hasattr(model_response.choices[0].delta, "role"):
                            del model_response.choices[0].delta.role
                        self.sent_first_chunk = True
                    hold, model_response_str = self.check_special_tokens(
                        chunk=completion_obj["content"],
                        finish_reason=model_response.choices[0].finish_reason,
@ -1614,6 +1630,24 @@ class CustomStreamWrapper:
            threading.Thread(
                target=self.logging_obj.failure_handler, args=(e, traceback_exception)
            ).start()
            if (
                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
            ):
                e = litellm.MidStreamFallbackError(
                    message=str(e),
                    model=self.model,
                    llm_provider=self.custom_llm_provider or "anthropic",
                    original_exception=e,
                    generated_content=self.response_uptil_now,
                    is_pre_first_chunk=not self.sent_first_chunk,
                )
                setattr(e, "streaming_obj", self)
                # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
                threading.Thread(
                    target=self.logging_obj.failure_handler, args=(e, traceback_exception),
                ).start()
            if isinstance(e, OpenAIError):
                raise e
            else:
@ -1803,6 +1837,24 @@ class CustomStreamWrapper:
                raise e
        except Exception as e:
            traceback_exception = traceback.format_exc()
            if (
                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
            ):
                e = litellm.MidStreamFallbackError(
                    message=str(e),
                    model=self.model,
                    llm_provider=self.custom_llm_provider or "anthropic",
                    original_exception=e,
                    generated_content=self.response_uptil_now,
                    is_pre_first_chunk=not self.sent_first_chunk,
                )
                setattr(e, "streaming_obj", self)
                # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
                threading.Thread(
                    target=self.logging_obj.failure_handler, args=(e, traceback_exception),
                ).start()
            if self.logging_obj is not None:
                ## LOGGING
                threading.Thread(

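A stream consumer can tell whether a chunk came from a mid-stream fallback by inspecting the same hidden params the wrapper checks above. A minimal sketch, assuming chunks expose _hidden_params as a dict; the helper name is illustrative:

def chunk_is_from_fallback(chunk) -> bool:
    # Looks for the same two markers as the wrapper: metadata["mid_stream_fallback"]
    # or the x-litellm-mid-stream-fallback response header.
    hidden = getattr(chunk, "_hidden_params", None)
    if not isinstance(hidden, dict):
        return False
    return bool(
        hidden.get("metadata", {}).get("mid_stream_fallback")
        or hidden.get("additional_headers", {}).get("x-litellm-mid-stream-fallback")
    )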
View file

@ -3250,6 +3250,45 @@ class Router:
                return response
            if isinstance(e, litellm.MidStreamFallbackError):
                fallbacks_source = fallbacks
                if fallbacks_source is None:
                    fallbacks_source = []
                verbose_router_logger.info(
                    f"Got MidStreamFallbackError. Pre-first chunk: {e.is_pre_first_chunk}, Generated content: {len(e.generated_content)} chars"
                )
                fallback_model_group: Optional[
                    List[str]
                ] = self._get_fallback_model_group_from_fallbacks(
                    fallbacks=fallbacks_source,
                    model_group=model_group,
                )
                if fallback_model_group is None:
                    raise original_exception
                input_kwargs.update(
                    {
                        "fallback_model_group": fallback_model_group,
                        "original_model_group": original_model_group,
                        "is_mid_stream_fallback": True,
                        "previous_content": e.generated_content,
                    }
                )
                verbose_router_logger.info(
                    f"Attempting {'pre-stream' if e.is_pre_first_chunk else 'mid-stream'} fallback. "
                    f"Already generated: {len(e.generated_content)} characters"
                )
                response = await run_async_fallback(
                    *args,
                    **input_kwargs,
                )
                return response
            if isinstance(e, litellm.ContextWindowExceededError):
                if context_window_fallbacks is not None:
                    fallback_model_group: Optional[List[str]] = (

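The branch above only fires when a fallback group is configured for the failing model group. A minimal configuration sketch mirroring the test further below; the model names are illustrative:

from litellm.router import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-primary",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        },
        {
            "model_name": "claude-backup",
            "litellm_params": {"model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"},
        },
    ],
    # When streaming from claude-primary raises MidStreamFallbackError, the handler
    # above resolves this mapping and retries the request on claude-backup.
    fallbacks=[{"claude-primary": ["claude-backup"]}],
)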
View file

@ -119,6 +119,41 @@ async def run_async_fallback(
    error_from_fallbacks = original_exception

    # Handle mid-stream fallbacks by preserving already generated content
    is_mid_stream = kwargs.pop("is_mid_stream_fallback", False)
    previous_content = kwargs.pop("previous_content", "")

    # If this is a mid-stream fallback and we have previous content, prepare messages
    if is_mid_stream and previous_content and "messages" in kwargs:
        messages = kwargs.get("messages", [])
        if isinstance(messages, list) and len(messages) > 0:
            if previous_content.strip():
                # Check for a system message
                system_msg_idx = None
                for i, msg in enumerate(messages):
                    if msg.get("role") == "system":
                        system_msg_idx = i
                        break
                continuation_text = f"The following is the beginning of an assistant's response. Continue from where it left off: '{previous_content}'"
                if system_msg_idx is not None:
                    # Append to existing system message
                    messages[system_msg_idx]["content"] = messages[system_msg_idx].get("content", "") + "\n\n" + continuation_text
                else:
                    # Add a new system message
                    messages.insert(0, {"role": "system", "content": continuation_text})
                # Update kwargs with modified messages
                kwargs["messages"] = messages
        # Add to metadata to track this was a mid-stream fallback
        kwargs.setdefault("metadata", {}).update({
            "is_mid_stream_fallback": True,
            "fallback_depth": fallback_depth,
            "previous_content_length": len(previous_content)
        })

    for mg in fallback_model_group:
        if mg == original_model_group:
            continue
@ -139,11 +174,22 @@ async def run_async_fallback(
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            if hasattr(response, "_hidden_params"):
                response._hidden_params.setdefault("metadata", {})["mid_stream_fallback"] = True
                # Also add to additional_headers for header propagation
                response._hidden_params.setdefault("additional_headers", {})["x-litellm-mid-stream-fallback"] = True

            verbose_router_logger.info("Successful fallback b/w models.")

            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )

            # If this was a mid-stream fallback, also add that to response headers
            if is_mid_stream and hasattr(response, "_hidden_params"):
                response._hidden_params.setdefault("additional_headers", {})
                response._hidden_params["additional_headers"]["x-litellm-mid-stream-fallback"] = True
                response._hidden_params["additional_headers"]["x-litellm-previous-content-length"] = len(previous_content)

            # callback for successfull_fallback_event():
            await log_success_fallback_event(
                original_model_group=original_model_group,

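The continuation logic above rewrites the message list so the fallback model resumes the partial response instead of restarting it. A standalone sketch of that transform; the helper name is illustrative and it assumes the same system-message strategy as the code above:

from typing import Dict, List

def build_continuation_messages(messages: List[Dict[str, str]], previous_content: str) -> List[Dict[str, str]]:
    # Instruct the fallback model to pick up where the original response stopped.
    continuation = (
        "The following is the beginning of an assistant's response. "
        f"Continue from where it left off: '{previous_content}'"
    )
    out = [dict(m) for m in messages]
    for msg in out:
        if msg.get("role") == "system":
            # Append the instruction to the existing system prompt.
            msg["content"] = (msg.get("content") or "") + "\n\n" + continuation
            return out
    # No system message present: prepend one carrying the continuation instruction.
    out.insert(0, {"role": "system", "content": continuation})
    return out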
View file

@ -0,0 +1,164 @@
import os
import sys
from typing import List

from dotenv import load_dotenv

from litellm.exceptions import ServiceUnavailableError, MidStreamFallbackError
from litellm.types.llms.openai import AllMessageValues
from litellm.router import Router

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm import completion, completion_cost, embedding
from litellm.types.utils import Delta, StreamingChoices

litellm.set_verbose = True
litellm.success_callback = []

user_message = "Write a short poem about the sky"
messages: List[AllMessageValues] = [{"content": user_message, "role": "user"}]


@pytest.mark.asyncio
async def test_anthropic_overload_fallback():
    """
    Test that when an Anthropic model fails mid-stream, it can fallback to another model
    """
    # Create a router with Claude model and a fallback
    _router = Router(
        model_list=[
            {
                "model_name": "claude-anthropic",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                    "api_key": os.environ.get("ANTHROPIC_API_KEY", "fake-key"),
                },
            },
            {
                "model_name": "claude-aws",
                "litellm_params": {
                    "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                    "api_key": os.environ.get("AWS_ACCESS_KEY_ID", "fake-key"),
                },
            },
        ],
        fallbacks=[{"claude-anthropic": ["claude-aws"]}],
    )

    # Messages for testing
    messages = [{"role": "user", "content": "Tell me about yourself"}]

    # Patch acompletion to simulate both the error and the fallback
    with patch('litellm.acompletion') as mock_acompletion:
        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            model = kwargs.get("model", "")
            is_fallback = kwargs.get("metadata", {}).get("mid_stream_fallback", False)

            if call_count == 1:  # First call - original model
                # Return a generator that will raise an error
                async def error_generator():
                    # First yield some content
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-test-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="anthropic/claude-3-5-sonnet-20240620",
                        )
                        yield chunk
                    # Then raise the error
                    raise ServiceUnavailableError(
                        message="AnthropicException - Overloaded. Handle with litellm.InternalServerError.",
                        model="anthropic/claude-3-5-sonnet-20240620",
                        llm_provider="anthropic",
                    )

                return error_generator()
            else:  # Second call - fallback model
                # Return a successful generator
                async def success_generator():
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-fallback-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Fallback token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                        )
                        # Add fallback header
                        chunk._hidden_params = {
                            "additional_headers": {
                                "x-litellm-fallback-used": True
                            }
                        }
                        yield chunk

                return success_generator()

        mock_acompletion.side_effect = mock_completion

        # Execute the test
        chunks = []
        full_content = ""
        try:
            async for chunk in await _router.acompletion(
                model="claude-anthropic",
                messages=messages,
                stream=True,
            ):
                chunks.append(chunk)
                if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    full_content += chunk.choices[0].delta.content

            # Verify we got chunks from both models
            print(f"Full content: {full_content}")
            assert "Token" in full_content, "Should contain content from the original model"
            assert "Fallback token" in full_content, "Should contain content from the fallback model"

            # Verify at least one chunk has fallback headers
            has_fallback_header = False
            for chunk in chunks:
                if (hasattr(chunk, "_hidden_params") and
                        chunk._hidden_params.get("additional_headers", {}).get("x-litellm-fallback-used", False)):
                    has_fallback_header = True
                    break
            assert has_fallback_header, "Should have fallback headers"
            print(f"Test passed! Mid-stream fallback worked correctly. Full content: {full_content}")
        except Exception as e:
            print(f"Error during streaming: {e}")
            # Print additional information for debugging
            print(f"Number of chunks: {len(chunks)}")
            for i, chunk in enumerate(chunks):
                print(f"Chunk {i}: {chunk}")
            raise