[Feature]: json_schema in response support for Anthropic (#6748)
* _convert_tool_response_to_message
* fix ModelResponseIterator
* fix test_json_response_format
* test_json_response_format_stream
* fix _convert_tool_response_to_message
* use helper _handle_json_mode_chunk
* fix _process_response
* unit testing for test_convert_tool_response_to_message_no_arguments
* update doc for JSON mode
parent a70a0688d8
commit 6ae0bc4a11
4 changed files with 221 additions and 9 deletions
@@ -75,6 +75,7 @@ Works for:
 - Google AI Studio - Gemini models
 - Vertex AI models (Gemini + Anthropic)
 - Bedrock Models
+- Anthropic API Models

 <Tabs>
 <TabItem value="sdk" label="SDK">
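For orientation, the capability documented here is driven entirely through `response_format`. A minimal usage sketch, consistent with the tests added below (the model name is illustrative, not part of this diff):

```python
import litellm

# JSON mode: request a JSON object and read it back from message.content
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[
        {
            "role": "user",
            "content": "Respond with this in json. city=San Francisco, state=CA",
        }
    ],
    response_format={"type": "json_object"},
)
# After this change, Anthropic's JSON lands in the content (OpenAI format)
print(response.choices[0].message.content)
```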
@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
         raise AnthropicError(status_code=500, message=str(e))

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
     )

     # LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
     )

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
     )

     # LOGGING
@@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
                     "arguments"
                 )
                 if json_mode_content_str is not None:
-                    args = json.loads(json_mode_content_str)
-                    values: Optional[dict] = args.get("values")
-                    if values is not None:
-                        _message = litellm.Message(content=json.dumps(values))
+                    _converted_message = self._convert_tool_response_to_message(
+                        tool_calls=tool_calls,
+                    )
+                    if _converted_message is not None:
+                        completion_response["stop_reason"] = "stop"
+                        _message = _converted_message
         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
@@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
         model_response._hidden_params = _hidden_params
         return model_response

+    @staticmethod
+    def _convert_tool_response_to_message(
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        """
+        In JSON mode, the Anthropic API returns the JSON schema as a tool call;
+        convert it to a message to follow the OpenAI format.
+        """
+        ## HANDLE JSON MODE - anthropic returns single function call
+        json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+            "arguments"
+        )
+        try:
+            if json_mode_content_str is not None:
+                args = json.loads(json_mode_content_str)
+                if (
+                    isinstance(args, dict)
+                    and (values := args.get("values")) is not None
+                ):
+                    _message = litellm.Message(content=json.dumps(values))
+                    return _message
+                else:
+                    # much of the time the `values` key is not present in the tool response
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+                    _message = litellm.Message(content=json.dumps(args))
+                    return _message
+        except json.JSONDecodeError:
+            # JSON decode errors do occur; return the original tool response str
+            return litellm.Message(content=json_mode_content_str)
+        return None
+
     async def acompletion_stream_function(
         self,
         model: str,
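A hedged sketch of what the new helper returns for a JSON-mode tool call, mirroring the unit tests added at the end of this commit (the `id` value is illustrative):

```python
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import ChatCompletionToolCallChunk

tool_calls = [
    ChatCompletionToolCallChunk(
        id="test_id",  # illustrative
        type="function",
        function=ChatCompletionToolCallFunctionChunk(
            name="json_tool_call",
            arguments='{"values": {"city": "San Francisco"}}',
        ),
        index=0,
    )
]

message = AnthropicChatCompletion._convert_tool_response_to_message(
    tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"city": "San Francisco"}'  # "values" wrapper stripped
```

When `values` is absent, the whole arguments object becomes the content; when the arguments are not valid JSON, the raw string passes through unchanged.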
@@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
         stream,
         _is_function_call,
         data: dict,
+        json_mode: bool,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -501,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
                 optional_params=optional_params,
                 stream=stream,
                 _is_function_call=_is_function_call,
+                json_mode=json_mode,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
@@ -548,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                json_mode=json_mode,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -606,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):


 class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
         self.streaming_response = streaming_response
         self.response_iterator = self.streaming_response
         self.content_blocks: List[ContentBlockDelta] = []
         self.tool_index = -1
+        self.json_mode = json_mode

     def check_empty_tool_call_args(self) -> bool:
         """
@@ -772,6 +817,8 @@ class ModelResponseIterator:
                     status_code=500,  # it looks like Anthropic API does not return a status code in the chunk error - default to 500
                 )

+            text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
+
             returned_chunk = GenericStreamingChunk(
                 text=text,
                 tool_use=tool_use,
@@ -786,6 +833,34 @@ class ModelResponseIterator:
         except json.JSONDecodeError:
             raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

+    def _handle_json_mode_chunk(
+        self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
+    ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
+        """
+        If JSON mode is enabled, convert the tool call to a message.
+
+        Anthropic returns the JSON schema as part of the tool call;
+        OpenAI returns it as part of the content. This places it in the content.
+
+        Args:
+            text: str
+            tool_use: Optional[ChatCompletionToolCallChunk]
+        Returns:
+            Tuple[str, Optional[ChatCompletionToolCallChunk]]
+
+            text: the text to use in the content
+            tool_use: the ChatCompletionToolCallChunk to use in the chunk response
+        """
+        if self.json_mode is True and tool_use is not None:
+            message = AnthropicChatCompletion._convert_tool_response_to_message(
+                tool_calls=[tool_use]
+            )
+            if message is not None:
+                text = message.content or ""
+                tool_use = None
+
+        return text, tool_use
+
     # Sync iterator
     def __iter__(self):
         return self
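A hedged sketch of the chunk-level conversion (the `id` is illustrative; on a live stream the `arguments` arrive as fragments, which fail `json.loads` inside the helper and pass through verbatim as text):

```python
from litellm.llms.anthropic.chat.handler import ModelResponseIterator
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import ChatCompletionToolCallChunk

iterator = ModelResponseIterator(
    streaming_response=iter([]), sync_stream=True, json_mode=True
)
tool_use = ChatCompletionToolCallChunk(
    id="test_id",  # illustrative
    type="function",
    function=ChatCompletionToolCallFunctionChunk(
        name="json_tool_call", arguments='{"weather": "sunny"}'
    ),
    index=0,
)

# The tool-use delta is folded into plain text content
text, tool_use = iterator._handle_json_mode_chunk(text="", tool_use=tool_use)
assert text == '{"weather": "sunny"}'
assert tool_use is None
```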
@@ -48,6 +48,9 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

+        # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
+        assert response.choices[0].message.content is not None
+
     def test_message_with_name(self):
         base_completion_call_args = self.get_base_completion_call_args()
         messages = [
@@ -82,6 +85,49 @@ class BaseLLMChatTest(ABC):

         print(response)

+        # OpenAI guarantees that the JSON schema is returned in the content
+        # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+        assert response.choices[0].message.content is not None
+
+    def test_json_response_format_stream(self):
+        """
+        Test that the JSON response format with streaming is supported by the LLM API
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your output should be a JSON object with no additional properties. ",
+            },
+            {
+                "role": "user",
+                "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
+            },
+        ]
+
+        response = litellm.completion(
+            **base_completion_call_args,
+            messages=messages,
+            response_format={"type": "json_object"},
+            stream=True,
+        )
+
+        print(response)
+
+        content = ""
+        for chunk in response:
+            content += chunk.choices[0].delta.content or ""
+
+        print("content=", content)
+
+        # OpenAI guarantees that the JSON schema is returned in the content
+        # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+        # assert that the JSON schema was returned in the content (for Anthropic it was previously returned as part of the tool call)
+        assert content is not None
+        assert len(content) > 0
+
     @pytest.fixture
     def pdf_messages(self):
         import base64
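Outside the test harness, the same streaming behavior can be consumed like this; a hedged sketch (model name illustrative):

```python
import json

import litellm

stream = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[
        {"role": "user", "content": "Respond with this in json. weather=sunny"}
    ],
    response_format={"type": "json_object"},
    stream=True,
)

content = ""
for chunk in stream:
    content += chunk.choices[0].delta.content or ""

data = json.loads(content)  # the deltas now concatenate to valid JSON
print(data)
```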
@@ -33,8 +33,10 @@ from litellm import (
 )
 from litellm.adapters.anthropic_adapter import anthropic_adapter
 from litellm.types.llms.anthropic import AnthropicResponse

+from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
+from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
 from litellm.llms.anthropic.common_utils import process_anthropic_headers
 from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
 from httpx import Headers
 from base_llm_unit_tests import BaseLLMChatTest
@@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
         assert _document_validation["type"] == "document"
         assert _document_validation["source"]["media_type"] == "application/pdf"
         assert _document_validation["source"]["type"] == "base64"
+
+
+def test_convert_tool_response_to_message_with_values():
+    """Test converting a tool response with a 'values' key to a message"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call",
+                arguments='{"values": {"name": "John", "age": 30}}',
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_without_values():
+    """
+    Test converting a tool response without a 'values' key to a message.
+
+    The Anthropic API returns the JSON schema in the tool call; the OpenAI spec expects it in the message. This test ensures that the tool call is converted to a message correctly.
+
+    Relevant issue: https://github.com/BerriAI/litellm/issues/6741
+    """
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call", arguments='{"name": "John", "age": 30}'
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_invalid_json():
+    """Test converting a tool response with invalid JSON"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(
+                name="json_tool_call", arguments="invalid json"
+            ),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is not None
+    assert message.content == "invalid json"
+
+
+def test_convert_tool_response_to_message_no_arguments():
+    """Test converting a tool response with no arguments"""
+    tool_calls = [
+        ChatCompletionToolCallChunk(
+            id="test_id",
+            type="function",
+            function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
+            index=0,
+        )
+    ]
+
+    message = AnthropicChatCompletion._convert_tool_response_to_message(
+        tool_calls=tool_calls
+    )
+
+    assert message is None