[Feature]: json_schema in response support for Anthropic (#6748)
* _convert_tool_response_to_message
* fix ModelResponseIterator
* fix test_json_response_format
* test_json_response_format_stream
* fix _convert_tool_response_to_message
* use helper _handle_json_mode_chunk
* fix _process_response
* unit testing for test_convert_tool_response_to_message_no_arguments
* update doc for JSON mode
This commit is contained in: parent a70a0688d8, commit 6ae0bc4a11
4 changed files with 221 additions and 9 deletions
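What the change buys end users: a JSON-mode request against an Anthropic model now returns the JSON in `message.content`, matching the OpenAI format, instead of leaving it inside a tool call. A minimal sketch of that call path, assuming a configured Anthropic API key; the model name is illustrative:

```python
import litellm

# JSON mode against an Anthropic model; previously the JSON came back inside
# a tool call, with this commit it lands in message.content as it does for OpenAI.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # assumption: any Anthropic chat model
    messages=[
        {
            "role": "user",
            "content": "Respond with this in json. city=San Francisco, weather=sunny",
        }
    ],
    response_format={"type": "json_object"},
)

print(response.choices[0].message.content)  # e.g. '{"city": "San Francisco", "weather": "sunny"}'
```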
docs: JSON mode page

@@ -75,6 +75,7 @@ Works for:
 - Google AI Studio - Gemini models
 - Vertex AI models (Gemini + Anthropic)
 - Bedrock Models
+- Anthropic API Models

 <Tabs>
 <TabItem value="sdk" label="SDK">
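The SDK tab on this docs page pairs naturally with a structured-output example. A sketch assuming the OpenAI-style `json_schema` response_format shape; the model name and schema are illustrative, not taken from the commit:

```python
import litellm

# Enforcing a schema rather than bare JSON mode; the "json_schema" shape below
# follows the OpenAI convention and is an assumption here.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # assumption: any Anthropic chat model
    messages=[{"role": "user", "content": "Extract: city=San Francisco, temp=60"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "weather_report",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "temp": {"type": "number"},
                },
                "required": ["city", "temp"],
            },
        },
    },
)

# With this commit, the schema-conforming JSON is in content, not a tool call.
print(response.choices[0].message.content)
```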
litellm/llms/anthropic/chat/handler.py

@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
         raise AnthropicError(status_code=500, message=str(e))

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
     )

     # LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
     )

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
     )

     # LOGGING
@@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
                     "arguments"
                 )
                 if json_mode_content_str is not None:
-                    args = json.loads(json_mode_content_str)
-                    values: Optional[dict] = args.get("values")
-                    if values is not None:
-                        _message = litellm.Message(content=json.dumps(values))
+                    _converted_message = self._convert_tool_response_to_message(
+                        tool_calls=tool_calls,
+                    )
+                    if _converted_message is not None:
                         completion_response["stop_reason"] = "stop"
+                        _message = _converted_message
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
@@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
         model_response._hidden_params = _hidden_params
         return model_response

+    @staticmethod
+    def _convert_tool_response_to_message(
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        """
+        In JSON mode, the Anthropic API returns the JSON schema as a tool call;
+        convert it to a message so the response follows the OpenAI format.
+        """
+        ## HANDLE JSON MODE - anthropic returns single function call
+        json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+            "arguments"
+        )
+        try:
+            if json_mode_content_str is not None:
+                args = json.loads(json_mode_content_str)
+                if (
+                    isinstance(args, dict)
+                    and (values := args.get("values")) is not None
+                ):
+                    _message = litellm.Message(content=json.dumps(values))
+                    return _message
+                else:
+                    # often the `values` key is not present in the tool response
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+                    _message = litellm.Message(content=json.dumps(args))
+                    return _message
+        except json.JSONDecodeError:
+            # on a JSON decode error, return the original tool response string
+            return litellm.Message(content=json_mode_content_str)
+        return None
+
     async def acompletion_stream_function(
         self,
         model: str,
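A sketch of the helper's three behaviors, mirroring the unit tests added at the bottom of this commit. The tool calls are built as plain dicts, which works because the handler itself uses dict-style access on `ChatCompletionToolCallChunk`; the `tool_call` helper is hypothetical, for brevity only:

```python
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion


def tool_call(arguments: str) -> list:
    # hypothetical helper; mirrors the ChatCompletionToolCallChunk shape
    # the new unit tests construct
    return [
        {
            "id": "test_id",
            "type": "function",
            "function": {"name": "json_tool_call", "arguments": arguments},
            "index": 0,
        }
    ]


# 1) "values" wrapper present -> unwrapped
msg = AnthropicChatCompletion._convert_tool_response_to_message(
    tool_calls=tool_call('{"values": {"name": "John", "age": 30}}')
)
assert msg is not None and msg.content == '{"name": "John", "age": 30}'

# 2) no "values" wrapper (issue #6741) -> the object is returned as-is
msg = AnthropicChatCompletion._convert_tool_response_to_message(
    tool_calls=tool_call('{"name": "John", "age": 30}')
)
assert msg is not None and msg.content == '{"name": "John", "age": 30}'

# 3) arguments that are not valid JSON -> the raw string is preserved
msg = AnthropicChatCompletion._convert_tool_response_to_message(
    tool_calls=tool_call("invalid json")
)
assert msg is not None and msg.content == "invalid json"
```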
@@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
         stream,
         _is_function_call,
         data: dict,
+        json_mode: bool,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -501,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
                 optional_params=optional_params,
                 stream=stream,
                 _is_function_call=_is_function_call,
+                json_mode=json_mode,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
@@ -548,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                json_mode=json_mode,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -606,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):


 class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
         self.streaming_response = streaming_response
         self.response_iterator = self.streaming_response
         self.content_blocks: List[ContentBlockDelta] = []
         self.tool_index = -1
+        self.json_mode = json_mode

     def check_empty_tool_call_args(self) -> bool:
         """
@@ -772,6 +817,8 @@ class ModelResponseIterator:
                 status_code=500,  # it looks like Anthropic API does not return a status code in the chunk error - default to 500
             )

+        text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
+
         returned_chunk = GenericStreamingChunk(
             text=text,
             tool_use=tool_use,
@@ -786,6 +833,34 @@ class ModelResponseIterator:
         except json.JSONDecodeError:
             raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

+    def _handle_json_mode_chunk(
+        self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
+    ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
+        """
+        If JSON mode is enabled, convert the tool call to a message.
+
+        Anthropic returns the JSON schema as part of the tool call;
+        OpenAI returns it as part of the content. This places it in the content.
+
+        Args:
+            text: str
+            tool_use: Optional[ChatCompletionToolCallChunk]
+        Returns:
+            Tuple[str, Optional[ChatCompletionToolCallChunk]]
+
+            text: the text to use in the content
+            tool_use: the ChatCompletionToolCallChunk to use in the chunk response
+        """
+        if self.json_mode is True and tool_use is not None:
+            message = AnthropicChatCompletion._convert_tool_response_to_message(
+                tool_calls=[tool_use]
+            )
+            if message is not None:
+                text = message.content or ""
+                tool_use = None
+
+        return text, tool_use
+
     # Sync iterator
     def __iter__(self):
         return self
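On the streaming side, `_handle_json_mode_chunk` rewrites each tool-call delta into a content delta. A minimal sketch exercising the hook directly; constructing the iterator by hand like this is only for illustration (in real use `make_call`/`make_sync_call` build it):

```python
from litellm.llms.anthropic.chat.handler import ModelResponseIterator

# Built by hand purely to exercise the hook; no network access is involved.
it = ModelResponseIterator(streaming_response=iter([]), sync_stream=True, json_mode=True)

# A complete tool-call delta is converted into a content delta.
text, tool_use = it._handle_json_mode_chunk(
    text="",
    tool_use={
        "id": "call_1",
        "type": "function",
        "function": {"name": "json_tool_call", "arguments": '{"values": {"temp": 60}}'},
        "index": 0,
    },
)
assert text == '{"temp": 60}' and tool_use is None

# Partial argument fragments fail json.loads and pass through verbatim, so
# accumulating `text` across chunks reassembles the full JSON string.
text, tool_use = it._handle_json_mode_chunk(
    text="",
    tool_use={
        "id": "call_1",
        "type": "function",
        "function": {"name": "json_tool_call", "arguments": '{"val'},
        "index": 0,
    },
)
assert text == '{"val' and tool_use is None
```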
tests: base_llm_unit_tests.py

@@ -48,6 +48,9 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

+        # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
+        assert response.choices[0].message.content is not None
+
     def test_message_with_name(self):
         base_completion_call_args = self.get_base_completion_call_args()
         messages = [
@@ -82,6 +85,49 @@ class BaseLLMChatTest(ABC):

         print(response)

+        # OpenAI guarantees that the JSON schema is returned in the content
+        # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+        assert response.choices[0].message.content is not None
+
+    def test_json_response_format_stream(self):
+        """
+        Test that the JSON response format with streaming is supported by the LLM API
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your output should be a JSON object with no additional properties. ",
+            },
+            {
+                "role": "user",
+                "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
+            },
+        ]
+
+        response = litellm.completion(
+            **base_completion_call_args,
+            messages=messages,
+            response_format={"type": "json_object"},
+            stream=True,
+        )
+
+        print(response)
+
+        content = ""
+        for chunk in response:
+            content += chunk.choices[0].delta.content or ""
+
+        print("content=", content)
+
+        # OpenAI guarantees that the JSON schema is returned in the content
+        # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+        # assert that the JSON schema was returned in the content (for Anthropic
+        # it was previously returned as part of the tool call)
+        assert content is not None
+        assert len(content) > 0
+
     @pytest.fixture
     def pdf_messages(self):
         import base64
tests: Anthropic completion suite (class TestAnthropicCompletion)

@@ -33,8 +33,10 @@ from litellm import (
 )
 from litellm.adapters.anthropic_adapter import anthropic_adapter
 from litellm.types.llms.anthropic import AnthropicResponse
+from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
+from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
 from litellm.llms.anthropic.common_utils import process_anthropic_headers
+from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
 from httpx import Headers
 from base_llm_unit_tests import BaseLLMChatTest
@@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
         assert _document_validation["type"] == "document"
         assert _document_validation["source"]["media_type"] == "application/pdf"
         assert _document_validation["source"]["type"] == "base64"
+
+
+def test_convert_tool_response_to_message_with_values():
+    """Test converting a tool response with a 'values' key to a message"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call",
+                arguments='{"values": {"name": "John", "age": 30}}',
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_without_values():
+    """
+    Test converting a tool response without a 'values' key to a message.
+
+    The Anthropic API returns the JSON schema in the tool call; the OpenAI spec
+    expects it in the message. This test ensures the tool call is converted to
+    a message correctly.
+
+    Relevant issue: https://github.com/BerriAI/litellm/issues/6741
+    """
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call", arguments='{"name": "John", "age": 30}'
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_invalid_json():
+    """Test converting a tool response with invalid JSON"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call", arguments="invalid json"
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == "invalid json"
+
+
+def test_convert_tool_response_to_message_no_arguments():
+    """Test converting a tool response with no arguments"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is None