Merge 64ec1bb016 into b82af5b826
This commit is contained in: commit 42f3c922f4
7 changed files with 390 additions and 6 deletions
@@ -1046,6 +1046,7 @@ from .exceptions import (
    JSONSchemaValidationError,
    LITELLM_EXCEPTION_TYPES,
    MockException,
    MidStreamFallbackError,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
@@ -550,6 +550,67 @@ class InternalServerError(openai.InternalServerError): # type: ignore
        return _message


class MidStreamFallbackError(ServiceUnavailableError): # type: ignore
    def __init__(
        self,
        message: str,
        model: str,
        llm_provider: str,
        original_exception: Optional[Exception] = None,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
        generated_content: str = "",
        is_pre_first_chunk: bool = False,
    ):
        self.status_code = 503  # Service Unavailable
        self.message = f"litellm.MidStreamFallbackError: {message}"
        self.model = model
        self.llm_provider = llm_provider
        self.original_exception = original_exception
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        self.generated_content = generated_content
        self.is_pre_first_chunk = is_pre_first_chunk

        # Create a response if one wasn't provided
        if response is None:
            self.response = httpx.Response(
                status_code=self.status_code,
                request=httpx.Request(
                    method="POST",
                    url=f"https://{llm_provider}.com/v1/",
                ),
            )
        else:
            self.response = response

        # Call the parent constructor
        super().__init__(
            message=self.message,
            llm_provider=llm_provider,
            model=model,
            response=self.response,
            litellm_debug_info=self.litellm_debug_info,
            max_retries=self.max_retries,
            num_retries=self.num_retries,
        )

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        if self.original_exception:
            _message += f" Original exception: {type(self.original_exception).__name__}: {str(self.original_exception)}"
        return _message

    def __repr__(self):
        return self.__str__()


# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(openai.APIError): # type: ignore
    def __init__(
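A minimal consumer-side sketch of the new exception (hypothetical usage, not part of the diff; it assumes the error surfaces while iterating a streamed litellm.completion call, as wired up in the hunks below):

# Hypothetical consumer-side sketch (illustration only): catch the new error
# while iterating a stream and fall back on the partial output it carries.
import litellm
from litellm.exceptions import MidStreamFallbackError

def stream_with_partial_recovery(messages):
    text = ""
    try:
        for chunk in litellm.completion(
            model="anthropic/claude-3-5-sonnet-20240620",
            messages=messages,
            stream=True,
        ):
            delta = chunk.choices[0].delta
            if getattr(delta, "content", None):
                text += delta.content
    except MidStreamFallbackError as e:
        # generated_content holds whatever streamed before the failure;
        # is_pre_first_chunk says whether anything reached the caller yet.
        text = text or e.generated_content
    return text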
@@ -735,6 +796,7 @@ LITELLM_EXCEPTION_TYPES = [
    OpenAIError,
    InternalServerError,
    JSONSchemaValidationError,
    MidStreamFallbackError,
]
@@ -15,6 +15,7 @@ from ..exceptions import (
    ContentPolicyViolationError,
    ContextWindowExceededError,
    InternalServerError,
    MidStreamFallbackError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
@@ -486,13 +487,32 @@ def exception_type( # type: ignore # noqa: PLR0915
                    model=model,
                    llm_provider="anthropic",
                )
            elif "overloaded_error" in error_str:
            elif ("overloaded" in error_str.lower() or
                  (hasattr(original_exception, "status_code") and
                   original_exception.status_code == 529)):
                exception_mapping_worked = True
                raise InternalServerError(
                    message="AnthropicError - {}".format(error_str),
                    model=model,
                    llm_provider="anthropic",
                )

                streaming_obj = getattr(original_exception, "streaming_obj", None)
                generated_content = ""
                is_pre_first_chunk = True
                if streaming_obj:
                    is_pre_first_chunk = getattr(streaming_obj, "sent_first_chunk", False) is False
                    generated_content = getattr(streaming_obj, "response_uptil_now", "")

                raise MidStreamFallbackError(
                    message="AnthropicException - {}".format(error_str),
                    model=model,
                    llm_provider="anthropic",
                    original_exception=original_exception,
                    generated_content=generated_content,
                    is_pre_first_chunk=is_pre_first_chunk,
                )
            else:
                raise InternalServerError(
                    message="AnthropicException - {}".format(error_str),
                    model=model,
                    llm_provider="anthropic",
                )
            if "Invalid API Key" in error_str:
                exception_mapping_worked = True
                raise AuthenticationError(
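To make the new trigger explicit, here is a small stand-alone sketch of the same check (hypothetical helper, not part of the diff; 529 is the Anthropic "overloaded" status code referenced above):

# Hypothetical helper mirroring the condition added above (illustration only).
from typing import Optional

def is_anthropic_overload(error_str: str, status_code: Optional[int] = None) -> bool:
    # Matches either the "overloaded" text Anthropic returns or its HTTP 529 status.
    return "overloaded" in error_str.lower() or status_code == 529

assert is_anthropic_overload("Overloaded", None) is True
assert is_anthropic_overload("rate limited", 529) is True
assert is_anthropic_overload("invalid api key", 401) is False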
@@ -779,6 +779,22 @@ class CustomStreamWrapper:
                    is_chunk_non_empty
                ):  # cannot set content of an OpenAI Object to be an empty string
                    self.safety_checker()

                    # Check if this is a response from a mid-stream fallback
                    is_mid_stream_fallback = (
                        hasattr(model_response, "_hidden_params") and
                        isinstance(model_response._hidden_params, dict) and
                        model_response._hidden_params.get("metadata", {}).get("mid_stream_fallback", False) or
                        model_response._hidden_params.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False)
                    )
                    if is_mid_stream_fallback and self.sent_first_chunk is False:
                        # Skip sending the role in the first chunk since it was sent by the original call
                        if hasattr(model_response.choices[0].delta, "role"):
                            del model_response.choices[0].delta.role

                        self.sent_first_chunk = True


                    hold, model_response_str = self.check_special_tokens(
                        chunk=completion_obj["content"],
                        finish_reason=model_response.choices[0].finish_reason,
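For reference, a hedged sketch of the two markers this check looks for; the exact dict layout is an assumption drawn from the hunk above:

# Illustration only: either marker below makes is_mid_stream_fallback evaluate truthy.
hidden_params = {
    "metadata": {"mid_stream_fallback": True},
    "additional_headers": {"x-litellm-mid-stream-fallback": True},
}
is_mid_stream_fallback = (
    hidden_params.get("metadata", {}).get("mid_stream_fallback", False)
    or hidden_params.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False)
)
assert is_mid_stream_fallback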
@@ -1614,6 +1630,24 @@ class CustomStreamWrapper:
            threading.Thread(
                target=self.logging_obj.failure_handler, args=(e, traceback_exception)
            ).start()
            if (
                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
            ):
                e = litellm.MidStreamFallbackError(
                    message=str(e),
                    model=self.model,
                    llm_provider=self.custom_llm_provider or "anthropic",
                    original_exception=e,
                    generated_content=self.response_uptil_now,
                    is_pre_first_chunk=not self.sent_first_chunk,
                )
                setattr(e, "streaming_obj", self)

            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
            threading.Thread(
                target=self.logging_obj.failure_handler, args=(e, traceback_exception),
            ).start()
            if isinstance(e, OpenAIError):
                raise e
            else:
@@ -1803,6 +1837,24 @@ class CustomStreamWrapper:
            raise e
        except Exception as e:
            traceback_exception = traceback.format_exc()
            if (
                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
            ):
                e = litellm.MidStreamFallbackError(
                    message=str(e),
                    model=self.model,
                    llm_provider=self.custom_llm_provider or "anthropic",
                    original_exception=e,
                    generated_content=self.response_uptil_now,
                    is_pre_first_chunk=not self.sent_first_chunk,
                )
                setattr(e, "streaming_obj", self)

            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
            threading.Thread(
                target=self.logging_obj.failure_handler, args=(e, traceback_exception),
            ).start()
            if self.logging_obj is not None:
                ## LOGGING
                threading.Thread(
@@ -3250,6 +3250,45 @@ class Router:

                return response

            if isinstance(e, litellm.MidStreamFallbackError):
                fallbacks_source = fallbacks
                if fallbacks_source is None:
                    fallbacks_source = []

                verbose_router_logger.info(
                    f"Got MidStreamFallbackError. Pre-first chunk: {e.is_pre_first_chunk}, Generated content: {len(e.generated_content)} chars"
                )

                fallback_model_group: Optional[
                    List[str]
                ] = self._get_fallback_model_group_from_fallbacks(
                    fallbacks=fallbacks_source,
                    model_group=model_group,
                )
                if fallback_model_group is None:
                    raise original_exception

                input_kwargs.update(
                    {
                        "fallback_model_group": fallback_model_group,
                        "original_model_group": original_model_group,
                        "is_mid_stream_fallback": True,
                        "previous_content": e.generated_content,
                    }
                )

                verbose_router_logger.info(
                    f"Attempting {'pre-stream' if e.is_pre_first_chunk else 'mid-stream'} fallback. "
                    f"Already generated: {len(e.generated_content)} characters"
                )

                response = await run_async_fallback(
                    *args,
                    **input_kwargs,
                )
                return response

            if isinstance(e, litellm.ContextWindowExceededError):
                if context_window_fallbacks is not None:
                    fallback_model_group: Optional[List[str]] = (
@@ -119,6 +119,41 @@ async def run_async_fallback(

    error_from_fallbacks = original_exception

    # Handle mid-stream fallbacks by preserving already generated content
    is_mid_stream = kwargs.pop("is_mid_stream_fallback", False)
    previous_content = kwargs.pop("previous_content", "")

    # If this is a mid-stream fallback and we have previous content, prepare messages
    if is_mid_stream and previous_content and "messages" in kwargs:
        messages = kwargs.get("messages", [])

        if isinstance(messages, list) and len(messages) > 0:
            if previous_content.strip():
                # Check for a system message
                system_msg_idx = None
                for i, msg in enumerate(messages):
                    if msg.get("role") == "system":
                        system_msg_idx = i
                        break

                continuation_text = f"The following is the beginning of an assistant's response. Continue from where it left off: '{previous_content}'"

                if system_msg_idx is not None:
                    # Append to existing system message
                    messages[system_msg_idx]["content"] = messages[system_msg_idx].get("content", "") + "\n\n" + continuation_text
                else:
                    # Add a new system message
                    messages.insert(0, {"role": "system", "content": continuation_text})

        # Update kwargs with modified messages
        kwargs["messages"] = messages
        # Add to metadata to track this was a mid-stream fallback
        kwargs.setdefault("metadata", {}).update({
            "is_mid_stream_fallback": True,
            "fallback_depth": fallback_depth,
            "previous_content_length": len(previous_content)
        })

    for mg in fallback_model_group:
        if mg == original_model_group:
            continue
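The message rewrite above is the core of the mid-stream recovery; a compact stand-alone sketch of the same idea (hypothetical helper, not part of the diff):

# Hypothetical helper (illustration only): fold already-streamed text into the
# message list so the fallback model continues instead of starting over.
from typing import Dict, List

def build_continuation_messages(messages: List[Dict[str, str]], previous_content: str) -> List[Dict[str, str]]:
    if not previous_content.strip():
        return messages
    continuation = (
        "The following is the beginning of an assistant's response. "
        f"Continue from where it left off: '{previous_content}'"
    )
    messages = [dict(m) for m in messages]  # avoid mutating the caller's list
    for m in messages:
        if m.get("role") == "system":
            m["content"] = (m.get("content") or "") + "\n\n" + continuation
            return messages
    return [{"role": "system", "content": continuation}] + messages

# Example:
msgs = [{"role": "user", "content": "Write a short poem about the sky"}]
print(build_continuation_messages(msgs, "The sky at dawn")[0]["role"])  # -> "system"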
@@ -139,11 +174,22 @@ async def run_async_fallback(
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            if hasattr(response, "_hidden_params"):
                response._hidden_params.setdefault("metadata", {})["mid_stream_fallback"] = True
                # Also add to additional_headers for header propagation
                response._hidden_params.setdefault("additional_headers", {})["x-litellm-mid-stream-fallback"] = True
            verbose_router_logger.info("Successful fallback b/w models.")
            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )

            # If this was a mid-stream fallback, also add that to response headers
            if is_mid_stream and hasattr(response, "_hidden_params"):
                response._hidden_params.setdefault("additional_headers", {})
                response._hidden_params["additional_headers"]["x-litellm-mid-stream-fallback"] = True
                response._hidden_params["additional_headers"]["x-litellm-previous-content-length"] = len(previous_content)

            # callback for successful_fallback_event():
            await log_success_fallback_event(
                original_model_group=original_model_group,
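On the consumer side, the markers added here can be read back off the response; a hedged sketch, with attribute names taken from the hunk above:

# Illustration only: inspect the markers the fallback path attaches to a response.
def describe_fallback(response) -> str:
    hidden = getattr(response, "_hidden_params", {}) or {}
    headers = hidden.get("additional_headers", {})
    if headers.get("x-litellm-mid-stream-fallback"):
        kept = headers.get("x-litellm-previous-content-length", 0)
        return f"mid-stream fallback used, {kept} characters carried over"
    return "no mid-stream fallback"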
tests/local_testing/test_anthropic_overload_fallback.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import os
import sys
from typing import List
from dotenv import load_dotenv
from litellm.exceptions import ServiceUnavailableError, MidStreamFallbackError
from litellm.types.llms.openai import AllMessageValues
from litellm.router import Router

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm import completion, completion_cost, embedding
from litellm.types.utils import Delta, StreamingChoices

litellm.set_verbose = True
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages: List[AllMessageValues] = [{"content": user_message, "role": "user"}]


@pytest.mark.asyncio
async def test_anthropic_overload_fallback():
    """
    Test that when an Anthropic model fails mid-stream, it can fallback to another model
    """
    # Create a router with Claude model and a fallback
    _router = Router(
        model_list=[
            {
                "model_name": "claude-anthropic",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                    "api_key": os.environ.get("ANTHROPIC_API_KEY", "fake-key"),
                },
            },
            {
                "model_name": "claude-aws",
                "litellm_params": {
                    "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                    "api_key": os.environ.get("AWS_ACCESS_KEY_ID", "fake-key"),
                },
            },
        ],
        fallbacks=[{"claude-anthropic": ["claude-aws"]}],
    )

    # Messages for testing
    messages = [{"role": "user", "content": "Tell me about yourself"}]

    # Patch acompletion to simulate both the error and the fallback
    with patch('litellm.acompletion') as mock_acompletion:
        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            model = kwargs.get("model", "")
            is_fallback = kwargs.get("metadata", {}).get("mid_stream_fallback", False)

            if call_count == 1:  # First call - original model
                # Return a generator that will raise an error
                async def error_generator():
                    # First yield some content
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-test-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="anthropic/claude-3-5-sonnet-20240620",
                        )
                        yield chunk

                    # Then raise the error
                    raise ServiceUnavailableError(
                        message="AnthropicException - Overloaded. Handle with litellm.InternalServerError.",
                        model="anthropic/claude-3-5-sonnet-20240620",
                        llm_provider="anthropic",
                    )

                return error_generator()

            else:  # Second call - fallback model
                # Return a successful generator
                async def success_generator():
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-fallback-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Fallback token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                        )
                        # Add fallback header
                        chunk._hidden_params = {
                            "additional_headers": {
                                "x-litellm-fallback-used": True
                            }
                        }
                        yield chunk

                return success_generator()

        mock_acompletion.side_effect = mock_completion

        # Execute the test
        chunks = []
        full_content = ""

        try:
            async for chunk in await _router.acompletion(
                model="claude-anthropic",
                messages=messages,
                stream=True,
            ):
                chunks.append(chunk)
                if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    full_content += chunk.choices[0].delta.content

            # Verify we got chunks from both models
            print(f"Full content: {full_content}")
            assert "Token" in full_content, "Should contain content from the original model"
            assert "Fallback token" in full_content, "Should contain content from the fallback model"

            # Verify at least one chunk has fallback headers
            has_fallback_header = False
            for chunk in chunks:
                if (hasattr(chunk, "_hidden_params") and
                        chunk._hidden_params.get("additional_headers", {}).get("x-litellm-fallback-used", False)):
                    has_fallback_header = True
                    break

            assert has_fallback_header, "Should have fallback headers"

            print(f"Test passed! Mid-stream fallback worked correctly. Full content: {full_content}")

        except Exception as e:
            print(f"Error during streaming: {e}")
            # Print additional information for debugging
            print(f"Number of chunks: {len(chunks)}")
            for i, chunk in enumerate(chunks):
                print(f"Chunk {i}: {chunk}")
            raise
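To run the new test locally (assuming pytest and an asyncio plugin such as pytest-asyncio are installed, as the @pytest.mark.asyncio marker requires): pytest tests/local_testing/test_anthropic_overload_fallback.py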