Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
Openrouter streaming fixes + Anthropic 'file' message support (#9667)
* fix(openrouter/transformation.py): handle error in openrouter stream. Fixes https://github.com/Aider-AI/aider/issues/3550
* test(test_openrouter_chat_transformation.py): add unit tests
* feat(anthropic/chat/transformation.py): add openai 'file' message content type support. Closes https://github.com/BerriAI/litellm/issues/9463
* fix(factory.py): add bedrock converse support for openai 'file' message content type. Closes https://github.com/BerriAI/litellm/issues/9463
parent cba4a4abcb
commit b01de8030b

6 changed files with 243 additions and 21 deletions
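Before the per-file hunks, an illustrative sketch (not part of this diff) of the OpenAI-style 'file' content block this commit teaches the Anthropic and Bedrock Converse transforms to accept. The model name and base64 payload are placeholder assumptions; the message shape mirrors the unit test added at the end of this commit.

import litellm

# Hypothetical model choice and truncated base64 payload, for illustration only.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's this file about?"},
                {
                    "type": "file",
                    # data URL carrying a base64-encoded PDF (placeholder)
                    "file": {"file_data": "data:application/pdf;base64,JVBERi0x..."},
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)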
@@ -22,6 +22,7 @@ from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
     ChatCompletionAssistantToolCall,
+    ChatCompletionFileObject,
     ChatCompletionFunctionMessage,
     ChatCompletionImageObject,
     ChatCompletionTextObject,

@@ -1455,6 +1456,25 @@ def anthropic_messages_pt(  # noqa: PLR0915
                     user_content.append(_content_element)
                 elif m.get("type", "") == "document":
                     user_content.append(cast(AnthropicMessagesDocumentParam, m))
+                elif m.get("type", "") == "file":
+                    file_message = cast(ChatCompletionFileObject, m)
+                    file_data = file_message["file"].get("file_data")
+                    if file_data:
+                        image_chunk = convert_to_anthropic_image_obj(
+                            openai_image_url=file_data,
+                            format=file_message["file"].get("format"),
+                        )
+                        anthropic_document_param = (
+                            AnthropicMessagesDocumentParam(
+                                type="document",
+                                source=AnthropicContentParamSource(
+                                    type="base64",
+                                    media_type=image_chunk["media_type"],
+                                    data=image_chunk["data"],
+                                ),
+                            )
+                        )
+                        user_content.append(anthropic_document_param)
         elif isinstance(user_message_types_block["content"], str):
             _anthropic_content_text_element: AnthropicMessagesTextParam = {
                 "type": "text",

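What the new 'file' branch above produces, as a standalone sketch (the data URL is a truncated placeholder; the block shape follows the diff above):

from litellm.litellm_core_utils.prompt_templates.factory import (
    convert_to_anthropic_image_obj,
)

file_data = "data:application/pdf;base64,JVBERi0x..."  # placeholder payload
image_chunk = convert_to_anthropic_image_obj(openai_image_url=file_data, format=None)
# Per the branch above, the file part becomes a base64 Anthropic 'document' block:
anthropic_document_param = {
    "type": "document",
    "source": {
        "type": "base64",
        "media_type": image_chunk["media_type"],  # e.g. application/pdf
        "data": image_chunk["data"],
    },
}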
@@ -2885,6 +2905,11 @@ class BedrockConverseMessagesProcessor:
                                 image_url=image_url, format=format
                             )
                             _parts.append(_part)  # type: ignore
+                        elif element["type"] == "file":
+                            _part = await BedrockConverseMessagesProcessor._async_process_file_message(
+                                message=cast(ChatCompletionFileObject, element)
+                            )
+                            _parts.append(_part)
                 _cache_point_block = (
                     litellm.AmazonConverseConfig()._get_cache_point_block(
                         message_block=cast(

@@ -3054,6 +3079,45 @@ class BedrockConverseMessagesProcessor:
                 reasoning_content_blocks.append(bedrock_content_block)
         return reasoning_content_blocks
 
+    @staticmethod
+    def _process_file_message(message: ChatCompletionFileObject) -> BedrockContentBlock:
+        file_message = message["file"]
+        file_data = file_message.get("file_data")
+        file_id = file_message.get("file_id")
+
+        if file_data is None and file_id is None:
+            raise litellm.BadRequestError(
+                message="file_data and file_id cannot both be None. Got={}".format(
+                    message
+                ),
+                model="",
+                llm_provider="bedrock",
+            )
+        format = file_message.get("format")
+        return BedrockImageProcessor.process_image_sync(
+            image_url=cast(str, file_id or file_data), format=format
+        )
+
+    @staticmethod
+    async def _async_process_file_message(
+        message: ChatCompletionFileObject,
+    ) -> BedrockContentBlock:
+        file_message = message["file"]
+        file_data = file_message.get("file_data")
+        file_id = file_message.get("file_id")
+        format = file_message.get("format")
+        if file_data is None and file_id is None:
+            raise litellm.BadRequestError(
+                message="file_data and file_id cannot both be None. Got={}".format(
+                    message
+                ),
+                model="",
+                llm_provider="bedrock",
+            )
+        return await BedrockImageProcessor.process_image_async(
+            image_url=cast(str, file_id or file_data), format=format
+        )
+
+
 def _bedrock_converse_messages_pt(  # noqa: PLR0915
     messages: List,

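A rough sketch of how the new helper pair behaves (the file part below is a hand-built placeholder shaped like ChatCompletionFileObject): with neither file_data nor file_id set it raises BadRequestError, otherwise the value is handed to BedrockImageProcessor.

from litellm.litellm_core_utils.prompt_templates.factory import (
    BedrockConverseMessagesProcessor,
)

file_part = {
    "type": "file",
    "file": {"file_data": "data:application/pdf;base64,JVBERi0x..."},  # placeholder
}
# Sync path; the async twin awaits BedrockImageProcessor.process_image_async instead.
block = BedrockConverseMessagesProcessor._process_file_message(message=file_part)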
@@ -3126,6 +3190,13 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
                                     format=format,
                                 )
                                 _parts.append(_part)  # type: ignore
+                            elif element["type"] == "file":
+                                _part = (
+                                    BedrockConverseMessagesProcessor._process_file_message(
+                                        message=cast(ChatCompletionFileObject, element)
+                                    )
+                                )
+                                _parts.append(_part)
             _cache_point_block = (
                 litellm.AmazonConverseConfig()._get_cache_point_block(
                     message_block=cast(

@@ -12,6 +12,7 @@ import httpx
 
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.llms.openrouter import OpenRouterErrorMessage
 from litellm.types.utils import ModelResponse, ModelResponseStream
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig

@@ -71,6 +72,24 @@ class OpenrouterConfig(OpenAIGPTConfig):
 class OpenRouterChatCompletionStreamingHandler(BaseModelResponseIterator):
     def chunk_parser(self, chunk: dict) -> ModelResponseStream:
         try:
+            ## HANDLE ERROR IN CHUNK ##
+            if "error" in chunk:
+                error_chunk = chunk["error"]
+                error_message = OpenRouterErrorMessage(
+                    message="Message: {}, Metadata: {}, User ID: {}".format(
+                        error_chunk["message"],
+                        error_chunk.get("metadata", {}),
+                        error_chunk.get("user_id", ""),
+                    ),
+                    code=error_chunk["code"],
+                    metadata=error_chunk.get("metadata", {}),
+                )
+                raise OpenRouterException(
+                    message=error_message["message"],
+                    status_code=error_message["code"],
+                    headers=error_message["metadata"].get("headers", {}),
+                )
+
             new_choices = []
             for choice in chunk["choices"]:
                 choice["delta"]["reasoning_content"] = choice["delta"].get("reasoning")

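For context, a sketch of the streamed error payload the new guard handles (values are placeholders; the shape matches the unit test added below). Judging by the tests, such a chunk previously fell through to chunk["choices"] and surfaced as an opaque KeyError rather than the provider's message.

from litellm.llms.openrouter.chat.transformation import (
    OpenRouterChatCompletionStreamingHandler,
    OpenRouterException,
)

handler = OpenRouterChatCompletionStreamingHandler(
    streaming_response=None, sync_stream=True
)
error_chunk = {
    "error": {
        "message": "rate limited",  # placeholder values throughout
        "code": 429,
        "metadata": {"headers": {"retry-after": "30"}},
        "user_id": "user_123",
    }
}
try:
    handler.chunk_parser(error_chunk)
except OpenRouterException as e:
    print(e.status_code)  # 429, raised with the provider's actual error message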
@@ -1,25 +1,31 @@
 model_list:
   - model_name: "gpt-4o"
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
       api_base: http://0.0.0.0:8090
       rpm: 3
   - model_name: "gpt-4o-mini-openai"
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "openai/*"
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
   - model_name: "gemini-2.0-flash"
     litellm_params:
       model: gemini/gemini-2.0-flash
       api_key: os.environ/GEMINI_API_KEY
+  - model_name: openrouter_model
+    litellm_params:
+      model: openrouter/openrouter_model
+      api_key: os.environ/OPENROUTER_API_KEY
+      api_base: http://0.0.0.0:8090
+
 
 litellm_settings:
   num_retries: 0

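A sketch of exercising the new openrouter_model entry through a locally running LiteLLM proxy started with this config (assumptions: the default proxy port 4000 and a placeholder master key):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
stream = client.chat.completions.create(
    model="openrouter_model",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")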
litellm/types/llms/openrouter.py (new file, 9 lines)

@@ -0,0 +1,9 @@
+import json
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
+
+
+class OpenRouterErrorMessage(TypedDict):
+    message: str
+    code: int
+    metadata: Dict

test_openrouter_chat_transformation.py (new file, 81 lines)

@@ -0,0 +1,81 @@
+import json
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.llms.openrouter.chat.transformation import (
+    OpenRouterChatCompletionStreamingHandler,
+    OpenRouterException,
+)
+
+
+class TestOpenRouterChatCompletionStreamingHandler:
+    def test_chunk_parser_successful(self):
+        handler = OpenRouterChatCompletionStreamingHandler(
+            streaming_response=None, sync_stream=True
+        )
+
+        # Test input chunk
+        chunk = {
+            "id": "test_id",
+            "created": 1234567890,
+            "model": "test_model",
+            "choices": [
+                {"delta": {"content": "test content", "reasoning": "test reasoning"}}
+            ],
+        }
+
+        # Parse chunk
+        result = handler.chunk_parser(chunk)
+
+        # Verify response
+        assert result.id == "test_id"
+        assert result.object == "chat.completion.chunk"
+        assert result.created == 1234567890
+        assert result.model == "test_model"
+        assert len(result.choices) == 1
+        assert result.choices[0]["delta"]["reasoning_content"] == "test reasoning"
+
+    def test_chunk_parser_error_response(self):
+        handler = OpenRouterChatCompletionStreamingHandler(
+            streaming_response=None, sync_stream=True
+        )
+
+        # Test error chunk
+        error_chunk = {
+            "error": {
+                "message": "test error",
+                "code": 400,
+                "metadata": {"key": "value"},
+                "user_id": "test_user",
+            }
+        }
+
+        # Verify error handling
+        with pytest.raises(OpenRouterException) as exc_info:
+            handler.chunk_parser(error_chunk)
+
+        assert "Message: test error" in str(exc_info.value)
+        assert exc_info.value.status_code == 400
+
+    def test_chunk_parser_key_error(self):
+        handler = OpenRouterChatCompletionStreamingHandler(
+            streaming_response=None, sync_stream=True
+        )
+
+        # Test invalid chunk missing required fields
+        invalid_chunk = {"incomplete": "data"}
+
+        # Verify KeyError handling
+        with pytest.raises(OpenRouterException) as exc_info:
+            handler.chunk_parser(invalid_chunk)
+
+        assert "KeyError" in str(exc_info.value)
+        assert exc_info.value.status_code == 400

@@ -198,6 +198,42 @@ class BaseLLMChatTest(ABC):
             messages=image_messages,
         )
         assert response is not None
+
+    def test_file_data_unit_test(self, pdf_messages):
+        from litellm.utils import supports_pdf_input, return_raw_request
+        from litellm.types.utils import CallTypes
+        from litellm.litellm_core_utils.prompt_templates.factory import convert_to_anthropic_image_obj
+
+        media_chunk = convert_to_anthropic_image_obj(
+            openai_image_url=pdf_messages,
+            format=None,
+        )
+
+        file_content = [
+            {"type": "text", "text": "What's this file about?"},
+            {
+                "type": "file",
+                "file": {
+                    "file_data": pdf_messages,
+                }
+            },
+        ]
+
+        image_messages = [{"role": "user", "content": file_content}]
+
+        base_completion_call_args = self.get_base_completion_call_args()
+
+        if not supports_pdf_input(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support image input")
+
+        raw_request = return_raw_request(
+            endpoint=CallTypes.completion,
+            kwargs={**base_completion_call_args, "messages": image_messages},
+        )
+
+        print("RAW REQUEST", raw_request)
+
+        assert media_chunk["data"] in json.dumps(raw_request)
+
     def test_message_with_name(self):
         try:

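The pdf_messages fixture is not part of this diff; a minimal stand-in (an assumption, matching how the test feeds it to convert_to_anthropic_image_obj) would be a base64 data URL built from a local PDF:

import base64

# Hypothetical stand-in for the pdf_messages fixture used above.
with open("sample.pdf", "rb") as f:
    pdf_messages = "data:application/pdf;base64," + base64.b64encode(f.read()).decode()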