mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge 64ec1bb016 into b82af5b826
This commit is contained in: commit 42f3c922f4
7 changed files with 390 additions and 6 deletions
@@ -1046,6 +1046,7 @@ from .exceptions import (
     JSONSchemaValidationError,
     LITELLM_EXCEPTION_TYPES,
     MockException,
+    MidStreamFallbackError,
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
@@ -550,6 +550,67 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         return _message


+class MidStreamFallbackError(ServiceUnavailableError):  # type: ignore
+    def __init__(
+        self,
+        message: str,
+        model: str,
+        llm_provider: str,
+        original_exception: Optional[Exception] = None,
+        response: Optional[httpx.Response] = None,
+        litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
+        generated_content: str = "",
+        is_pre_first_chunk: bool = False,
+    ):
+        self.status_code = 503  # Service Unavailable
+        self.message = f"litellm.MidStreamFallbackError: {message}"
+        self.model = model
+        self.llm_provider = llm_provider
+        self.original_exception = original_exception
+        self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
+        self.generated_content = generated_content
+        self.is_pre_first_chunk = is_pre_first_chunk
+
+        # Create a response if one wasn't provided
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=f"https://{llm_provider}.com/v1/",
+                ),
+            )
+        else:
+            self.response = response
+
+        # Call the parent constructor
+        super().__init__(
+            message=self.message,
+            llm_provider=llm_provider,
+            model=model,
+            response=self.response,
+            litellm_debug_info=self.litellm_debug_info,
+            max_retries=self.max_retries,
+            num_retries=self.num_retries,
+        )
+
+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        if self.original_exception:
+            _message += f" Original exception: {type(self.original_exception).__name__}: {str(self.original_exception)}"
+        return _message
+
+    def __repr__(self):
+        return self.__str__()
+
+
 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
 class APIError(openai.APIError):  # type: ignore
     def __init__(
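For orientation, a minimal usage sketch (not part of this diff; it assumes the PR's changes are applied so that MidStreamFallbackError is importable from litellm.exceptions, as the new test file below does; the helper name is hypothetical):

from litellm.exceptions import MidStreamFallbackError


def handle_stream_failure(err: Exception) -> None:
    # Hypothetical helper, for illustration only.
    if isinstance(err, MidStreamFallbackError):
        partial = err.generated_content  # text produced before the provider failed
        if err.is_pre_first_chunk:
            print("Stream failed before the first chunk; safe to retry from scratch.")
        else:
            print(f"Stream failed after {len(partial)} characters; a fallback should continue from there.")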
@@ -735,6 +796,7 @@ LITELLM_EXCEPTION_TYPES = [
     OpenAIError,
     InternalServerError,
     JSONSchemaValidationError,
+    MidStreamFallbackError,
 ]
@@ -15,6 +15,7 @@ from ..exceptions import (
     ContentPolicyViolationError,
     ContextWindowExceededError,
     InternalServerError,
+    MidStreamFallbackError,
     NotFoundError,
     PermissionDeniedError,
     RateLimitError,
@@ -486,13 +487,32 @@ def exception_type(  # type: ignore # noqa: PLR0915
                     model=model,
                     llm_provider="anthropic",
                 )
-            elif "overloaded_error" in error_str:
+            elif ("overloaded" in error_str.lower() or
+                  (hasattr(original_exception, "status_code") and
+                   original_exception.status_code == 529)):
                 exception_mapping_worked = True
-                raise InternalServerError(
-                    message="AnthropicError - {}".format(error_str),
-                    model=model,
-                    llm_provider="anthropic",
-                )
+                streaming_obj = getattr(original_exception, "streaming_obj", None)
+                generated_content = ""
+                is_pre_first_chunk = True
+                if streaming_obj:
+                    is_pre_first_chunk = getattr(streaming_obj, "sent_first_chunk", False) is False
+                    generated_content = getattr(streaming_obj, "response_uptil_now", "")
+
+                raise MidStreamFallbackError(
+                    message="AnthropicException - {}".format(error_str),
+                    model=model,
+                    llm_provider="anthropic",
+                    original_exception=original_exception,
+                    generated_content=generated_content,
+                    is_pre_first_chunk=is_pre_first_chunk,
+                )
+            else:
+                raise InternalServerError(
+                    message="AnthropicException - {}".format(error_str),
+                    model=model,
+                    llm_provider="anthropic",
+                )
             if "Invalid API Key" in error_str:
                 exception_mapping_worked = True
                 raise AuthenticationError(
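As a standalone illustration of the detection rule above (the stub class and helper names are hypothetical, not part of litellm):

from typing import Optional


class _StubProviderError(Exception):
    # Hypothetical stand-in for a provider exception carrying a status code.
    def __init__(self, message: str, status_code: Optional[int] = None):
        super().__init__(message)
        self.status_code = status_code


def _looks_overloaded(exc: Exception) -> bool:
    # Mirrors the new branch: match "overloaded" in the message or HTTP 529.
    error_str = str(exc)
    return (
        "overloaded" in error_str.lower()
        or getattr(exc, "status_code", None) == 529
    )


assert _looks_overloaded(_StubProviderError("AnthropicException - Overloaded"))
assert _looks_overloaded(_StubProviderError("unrelated error", status_code=529))
assert not _looks_overloaded(_StubProviderError("rate limited", status_code=429))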
@@ -779,6 +779,22 @@ class CustomStreamWrapper:
                     is_chunk_non_empty
                 ):  # cannot set content of an OpenAI Object to be an empty string
                     self.safety_checker()
+
+                    # Check if this is a response from a mid-stream fallback
+                    is_mid_stream_fallback = (
+                        hasattr(model_response, "_hidden_params") and
+                        isinstance(model_response._hidden_params, dict) and
+                        (model_response._hidden_params.get("metadata", {}).get("mid_stream_fallback", False) or
+                         model_response._hidden_params.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False))
+                    )
+                    if is_mid_stream_fallback and self.sent_first_chunk is False:
+                        # Skip sending the role in the first chunk since it was sent by the original call
+                        if hasattr(model_response.choices[0].delta, "role"):
+                            del model_response.choices[0].delta.role
+
+                        self.sent_first_chunk = True
+
                     hold, model_response_str = self.check_special_tokens(
                         chunk=completion_obj["content"],
                         finish_reason=model_response.choices[0].finish_reason,
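Restated outside the wrapper for clarity, the check amounts to the following (the helper name is illustrative only):

from typing import Any


def _is_mid_stream_fallback_chunk(model_response: Any) -> bool:
    # A chunk is treated as the continuation of an earlier stream if either the
    # metadata flag or the propagated response header is set on _hidden_params.
    hidden = getattr(model_response, "_hidden_params", None)
    if not isinstance(hidden, dict):
        return False
    return bool(
        hidden.get("metadata", {}).get("mid_stream_fallback", False)
        or hidden.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False)
    )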
@@ -1614,6 +1630,24 @@ class CustomStreamWrapper:
             threading.Thread(
                 target=self.logging_obj.failure_handler, args=(e, traceback_exception)
             ).start()
+            if (
+                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
+                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
+            ):
+                e = litellm.MidStreamFallbackError(
+                    message=str(e),
+                    model=self.model,
+                    llm_provider=self.custom_llm_provider or "anthropic",
+                    original_exception=e,
+                    generated_content=self.response_uptil_now,
+                    is_pre_first_chunk=not self.sent_first_chunk,
+                )
+                setattr(e, "streaming_obj", self)
+
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            threading.Thread(
+                target=self.logging_obj.failure_handler, args=(e, traceback_exception),
+            ).start()
             if isinstance(e, OpenAIError):
                 raise e
             else:
@@ -1803,6 +1837,24 @@ class CustomStreamWrapper:
                 raise e
         except Exception as e:
             traceback_exception = traceback.format_exc()
+            if (
+                (isinstance(e, litellm.InternalServerError) and "overloaded" in str(e).lower()) or
+                (isinstance(e, Exception) and "overloaded" in str(e).lower() and self.custom_llm_provider == "anthropic")
+            ):
+                e = litellm.MidStreamFallbackError(
+                    message=str(e),
+                    model=self.model,
+                    llm_provider=self.custom_llm_provider or "anthropic",
+                    original_exception=e,
+                    generated_content=self.response_uptil_now,
+                    is_pre_first_chunk=not self.sent_first_chunk,
+                )
+                setattr(e, "streaming_obj", self)
+
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            threading.Thread(
+                target=self.logging_obj.failure_handler, args=(e, traceback_exception),
+            ).start()
             if self.logging_obj is not None:
                 ## LOGGING
                 threading.Thread(
@@ -3250,6 +3250,45 @@ class Router:

                 return response

+
+            if isinstance(e, litellm.MidStreamFallbackError):
+                fallbacks_source = fallbacks
+                if fallbacks_source is None:
+                    fallbacks_source = []
+
+                verbose_router_logger.info(
+                    f"Got MidStreamFallbackError. Pre-first chunk: {e.is_pre_first_chunk}, Generated content: {len(e.generated_content)} chars"
+                )
+
+                fallback_model_group: Optional[
+                    List[str]
+                ] = self._get_fallback_model_group_from_fallbacks(
+                    fallbacks=fallbacks_source,
+                    model_group=model_group,
+                )
+                if fallback_model_group is None:
+                    raise original_exception
+
+                input_kwargs.update(
+                    {
+                        "fallback_model_group": fallback_model_group,
+                        "original_model_group": original_model_group,
+                        "is_mid_stream_fallback": True,
+                        "previous_content": e.generated_content,
+                    }
+                )
+
+                verbose_router_logger.info(
+                    f"Attempting {'pre-stream' if e.is_pre_first_chunk else 'mid-stream'} fallback. "
+                    f"Already generated: {len(e.generated_content)} characters"
+                )
+
+                response = await run_async_fallback(
+                    *args,
+                    **input_kwargs,
+                )
+                return response
+
             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
                     fallback_model_group: Optional[List[str]] = (
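The router reuses the ordinary fallbacks mapping for this path; a minimal configuration sketch (model names and keys are placeholders, mirroring the new test below):

import os

from litellm.router import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-anthropic",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": os.environ.get("ANTHROPIC_API_KEY", "fake-key"),
            },
        },
        {
            "model_name": "claude-aws",
            "litellm_params": {
                "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
            },
        },
    ],
    # If claude-anthropic fails (including mid-stream), retry on claude-aws.
    fallbacks=[{"claude-anthropic": ["claude-aws"]}],
)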
@@ -119,6 +119,41 @@ async def run_async_fallback(

     error_from_fallbacks = original_exception

+    # Handle mid-stream fallbacks by preserving already generated content
+    is_mid_stream = kwargs.pop("is_mid_stream_fallback", False)
+    previous_content = kwargs.pop("previous_content", "")
+
+    # If this is a mid-stream fallback and we have previous content, prepare messages
+    if is_mid_stream and previous_content and "messages" in kwargs:
+        messages = kwargs.get("messages", [])
+
+        if isinstance(messages, list) and len(messages) > 0:
+            if previous_content.strip():
+                # Check for a system message
+                system_msg_idx = None
+                for i, msg in enumerate(messages):
+                    if msg.get("role") == "system":
+                        system_msg_idx = i
+                        break
+
+                continuation_text = f"The following is the beginning of an assistant's response. Continue from where it left off: '{previous_content}'"
+
+                if system_msg_idx is not None:
+                    # Append to existing system message
+                    messages[system_msg_idx]["content"] = messages[system_msg_idx].get("content", "") + "\n\n" + continuation_text
+                else:
+                    # Add a new system message
+                    messages.insert(0, {"role": "system", "content": continuation_text})
+
+                # Update kwargs with modified messages
+                kwargs["messages"] = messages
+                # Add to metadata to track this was a mid-stream fallback
+                kwargs.setdefault("metadata", {}).update({
+                    "is_mid_stream_fallback": True,
+                    "fallback_depth": fallback_depth,
+                    "previous_content_length": len(previous_content)
+                })
+
     for mg in fallback_model_group:
         if mg == original_model_group:
             continue
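The message rewrite above can be read as the following standalone transform (the helper name is illustrative; it is not the function added in this diff):

from typing import Dict, List


def _inject_continuation(messages: List[Dict[str, str]], previous_content: str) -> List[Dict[str, str]]:
    # Fold the partial assistant output into the system prompt so the fallback
    # model continues the answer instead of starting over.
    if not previous_content.strip():
        return messages
    continuation_text = (
        "The following is the beginning of an assistant's response. "
        f"Continue from where it left off: '{previous_content}'"
    )
    for msg in messages:
        if msg.get("role") == "system":
            msg["content"] = msg.get("content", "") + "\n\n" + continuation_text
            return messages
    messages.insert(0, {"role": "system", "content": continuation_text})
    return messages


# Example: a single user turn plus partial output "Roses are red," gains a system
# instruction telling the fallback model to continue from that text.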
@@ -139,11 +174,22 @@ async def run_async_fallback(
             response = await litellm_router.async_function_with_fallbacks(
                 *args, **kwargs
             )
+            if hasattr(response, "_hidden_params"):
+                response._hidden_params.setdefault("metadata", {})["mid_stream_fallback"] = True
+                # Also add to additional_headers for header propagation
+                response._hidden_params.setdefault("additional_headers", {})["x-litellm-mid-stream-fallback"] = True
             verbose_router_logger.info("Successful fallback b/w models.")
             response = add_fallback_headers_to_response(
                 response=response,
                 attempted_fallbacks=fallback_depth,
             )
+
+            # If this was a mid-stream fallback, also add that to response headers
+            if is_mid_stream and hasattr(response, "_hidden_params"):
+                response._hidden_params.setdefault("additional_headers", {})
+                response._hidden_params["additional_headers"]["x-litellm-mid-stream-fallback"] = True
+                response._hidden_params["additional_headers"]["x-litellm-previous-content-length"] = len(previous_content)
+
             # callback for successfull_fallback_event():
             await log_success_fallback_event(
                 original_model_group=original_model_group,
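Callers could then detect a stitched-together response from the propagated header; a small sketch under the attribute layout written above (the helper name is illustrative):

from typing import Any


def _was_mid_stream_fallback(response: Any) -> bool:
    # Reads the header set on _hidden_params by run_async_fallback above.
    hidden = getattr(response, "_hidden_params", None)
    if not isinstance(hidden, dict):
        return False
    return bool(hidden.get("additional_headers", {}).get("x-litellm-mid-stream-fallback", False))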
tests/local_testing/test_anthropic_overload_fallback.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import os
import sys
from typing import List

from dotenv import load_dotenv

from litellm.exceptions import ServiceUnavailableError, MidStreamFallbackError
from litellm.types.llms.openai import AllMessageValues
from litellm.router import Router

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm import completion, completion_cost, embedding
from litellm.types.utils import Delta, StreamingChoices

litellm.set_verbose = True
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages: List[AllMessageValues] = [{"content": user_message, "role": "user"}]


@pytest.mark.asyncio
async def test_anthropic_overload_fallback():
    """
    Test that when an Anthropic model fails mid-stream, it can fallback to another model
    """
    # Create a router with Claude model and a fallback
    _router = Router(
        model_list=[
            {
                "model_name": "claude-anthropic",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                    "api_key": os.environ.get("ANTHROPIC_API_KEY", "fake-key"),
                },
            },
            {
                "model_name": "claude-aws",
                "litellm_params": {
                    "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                    "api_key": os.environ.get("AWS_ACCESS_KEY_ID", "fake-key"),
                },
            },
        ],
        fallbacks=[{"claude-anthropic": ["claude-aws"]}],
    )

    # Messages for testing
    messages = [{"role": "user", "content": "Tell me about yourself"}]

    # Patch acompletion to simulate both the error and the fallback
    with patch('litellm.acompletion') as mock_acompletion:
        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            model = kwargs.get("model", "")
            is_fallback = kwargs.get("metadata", {}).get("mid_stream_fallback", False)

            if call_count == 1:  # First call - original model
                # Return a generator that will raise an error
                async def error_generator():
                    # First yield some content
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-test-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="anthropic/claude-3-5-sonnet-20240620",
                        )
                        yield chunk

                    # Then raise the error
                    raise ServiceUnavailableError(
                        message="AnthropicException - Overloaded. Handle with litellm.InternalServerError.",
                        model="anthropic/claude-3-5-sonnet-20240620",
                        llm_provider="anthropic",
                    )

                return error_generator()

            else:  # Second call - fallback model
                # Return a successful generator
                async def success_generator():
                    for i in range(3):
                        chunk = litellm.ModelResponse(
                            id=f"chatcmpl-fallback-{i}",
                            choices=[
                                StreamingChoices(
                                    delta=Delta(
                                        content=f"Fallback token {i} ",
                                        role="assistant" if i == 0 else None,
                                    ),
                                    index=0,
                                )
                            ],
                            model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
                        )
                        # Add fallback header
                        chunk._hidden_params = {
                            "additional_headers": {
                                "x-litellm-fallback-used": True
                            }
                        }
                        yield chunk

                return success_generator()

        mock_acompletion.side_effect = mock_completion

        # Execute the test
        chunks = []
        full_content = ""

        try:
            async for chunk in await _router.acompletion(
                model="claude-anthropic",
                messages=messages,
                stream=True,
            ):
                chunks.append(chunk)
                if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    full_content += chunk.choices[0].delta.content

            # Verify we got chunks from both models
            print(f"Full content: {full_content}")
            assert "Token" in full_content, "Should contain content from the original model"
            assert "Fallback token" in full_content, "Should contain content from the fallback model"

            # Verify at least one chunk has fallback headers
            has_fallback_header = False
            for chunk in chunks:
                if (hasattr(chunk, "_hidden_params") and
                        chunk._hidden_params.get("additional_headers", {}).get("x-litellm-fallback-used", False)):
                    has_fallback_header = True
                    break

            assert has_fallback_header, "Should have fallback headers"

            print(f"Test passed! Mid-stream fallback worked correctly. Full content: {full_content}")

        except Exception as e:
            print(f"Error during streaming: {e}")
            # Print additional information for debugging
            print(f"Number of chunks: {len(chunks)}")
            for i, chunk in enumerate(chunks):
                print(f"Chunk {i}: {chunk}")
            raise