Compare commits

9 commits

Author       SHA1        Message                                                               Date
Ishaan Jaff  0f75cd1837  update doc for JSON mode                                              2024-11-14 16:58:58 -08:00
Ishaan Jaff  26f7a6b7a2  unit testing for test_convert_tool_response_to_message_no_arguments   2024-11-14 13:12:39 -08:00
Ishaan Jaff  d77fd30f2f  fix _process_response                                                 2024-11-14 12:57:53 -08:00
Ishaan Jaff  1cdee5a50a  use helper _handle_json_mode_chunk                                    2024-11-14 12:56:37 -08:00
Ishaan Jaff  ce235facd0  fix _convert_tool_response_to_message                                 2024-11-14 12:45:37 -08:00
Ishaan Jaff  f5f36fb96c  test_json_response_format_stream                                      2024-11-14 11:32:52 -08:00
Ishaan Jaff  eaf5723b94  fix test_json_response_format                                         2024-11-14 10:58:00 -08:00
Ishaan Jaff  a866bdb01e  fix ModelResponseIterator                                             2024-11-14 10:53:37 -08:00
Ishaan Jaff  c3a2c77b55  _convert_tool_response_to_message                                     2024-11-14 10:51:04 -08:00
4 changed files with 221 additions and 9 deletions

View file

@@ -75,6 +75,7 @@ Works for:
- Google AI Studio - Gemini models
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
- Anthropic API Models
<Tabs>
<TabItem value="sdk" label="SDK">
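
For orientation, here is a minimal sketch of the JSON mode usage this doc change advertises for Anthropic API models. It is illustrative only: the model name, the prompt, and the assumption that ANTHROPIC_API_KEY is set in the environment are not taken from the diff.

# Minimal sketch, assuming ANTHROPIC_API_KEY is set; model and prompt are illustrative.
import json
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20241022",
    messages=[
        {"role": "system", "content": "Respond only with a JSON object."},
        {"role": "user", "content": "city=San Francisco, state=CA. Return this as JSON."},
    ],
    response_format={"type": "json_object"},
)

# With the changes in this compare, the JSON body lands in message.content
# (OpenAI format) instead of being left inside a tool call.
print(json.loads(response.choices[0].message.content))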

View file

@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
raise AnthropicError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
streaming_response=response.aiter_lines(),
sync_stream=False,
json_mode=json_mode,
)
# LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
@@ -175,7 +181,7 @@
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
)
# LOGGING
@@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
"arguments"
)
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
values: Optional[dict] = args.get("values")
if values is not None:
_message = litellm.Message(content=json.dumps(values))
_converted_message = self._convert_tool_response_to_message(
tool_calls=tool_calls,
)
if _converted_message is not None:
completion_response["stop_reason"] = "stop"
_message = _converted_message
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
@@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
model_response._hidden_params = _hidden_params
return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, the Anthropic API returns the JSON response as a tool call; we convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if (
isinstance(args, dict)
and (values := args.get("values")) is not None
):
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# often the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# a JSON decode error can occur; return the original tool response string
return litellm.Message(content=json_mode_content_str)
return None
async def acompletion_stream_function(
self,
model: str,
@@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
stream,
_is_function_call,
data: dict,
json_mode: bool,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@@ -500,6 +540,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
json_mode=json_mode,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
@@ -547,6 +588,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@@ -605,11 +647,14 @@ class AnthropicChatCompletion(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1
self.json_mode = json_mode
def check_empty_tool_call_args(self) -> bool:
"""
@@ -771,6 +816,8 @@ class ModelResponseIterator:
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
)
text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
@@ -785,6 +832,34 @@ class ModelResponseIterator:
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
def _handle_json_mode_chunk(
self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
"""
If JSON mode is enabled, convert the tool call to a message.
Anthropic returns the JSON response as part of the tool call, while OpenAI returns it
as part of the content; this helper places it in the content.
Args:
text: str
tool_use: Optional[ChatCompletionToolCallChunk]
Returns:
Tuple[str, Optional[ChatCompletionToolCallChunk]]
text: The text to use in the content
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
"""
if self.json_mode is True and tool_use is not None:
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=[tool_use]
)
if message is not None:
text = message.content or ""
tool_use = None
return text, tool_use
# Sync iterator
def __iter__(self):
return self
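
To make the new json_mode handling concrete, below is a minimal sketch that exercises it in isolation. It relies only on the signatures shown in this diff, pokes at a private helper, and is an illustration under those assumptions rather than a supported usage pattern.

# Sketch only: build the iterator directly with an empty stream and feed it a
# hand-made tool call chunk, mirroring what Anthropic returns in JSON mode.
from litellm.llms.anthropic.chat.handler import ModelResponseIterator
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import ChatCompletionToolCallChunk

iterator = ModelResponseIterator(
    streaming_response=iter([]),  # no real stream is needed for this sketch
    sync_stream=True,
    json_mode=True,
)

tool_use = ChatCompletionToolCallChunk(
    id="call_1",
    type="function",
    function=ChatCompletionToolCallFunctionChunk(
        name="json_tool_call",
        arguments='{"values": {"city": "San Francisco", "state": "CA"}}',
    ),
    index=0,
)

# With json_mode=True the tool call is folded back into plain text content and the
# tool_use slot is cleared, so downstream chunks follow the OpenAI format.
text, remaining_tool_use = iterator._handle_json_mode_chunk(text="", tool_use=tool_use)
print(text)                # {"city": "San Francisco", "state": "CA"}
print(remaining_tool_use)  # None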

View file

@@ -45,6 +45,9 @@ class BaseLLMChatTest(ABC):
)
assert response is not None
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
def test_message_with_name(self):
base_completion_call_args = self.get_base_completion_call_args()
messages = [
@@ -79,6 +82,49 @@ class BaseLLMChatTest(ABC):
print(response)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
assert response.choices[0].message.content is not None
def test_json_response_format_stream(self):
"""
Test that the JSON response format with streaming is supported by the LLM API
"""
base_completion_call_args = self.get_base_completion_call_args()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your output should be a JSON object with no additional properties. ",
},
{
"role": "user",
"content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
},
]
response = litellm.completion(
**base_completion_call_args,
messages=messages,
response_format={"type": "json_object"},
stream=True,
)
print(response)
content = ""
for chunk in response:
content += chunk.choices[0].delta.content or ""
print("content=", content)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
# assert that the JSON response was returned in the content (for Anthropic it was previously returned as part of the tool call)
assert content is not None
assert len(content) > 0
@pytest.fixture
def pdf_messages(self):
import base64

View file

@@ -33,8 +33,10 @@ from litellm import (
)
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.llms.anthropic.common_utils import process_anthropic_headers
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from httpx import Headers
from base_llm_unit_tests import BaseLLMChatTest
@@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
assert _document_validation["type"] == "document"
assert _document_validation["source"]["media_type"] == "application/pdf"
assert _document_validation["source"]["type"] == "base64"
def test_convert_tool_response_to_message_with_values():
"""Test converting a tool response with 'values' key to a message"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call",
arguments='{"values": {"name": "John", "age": 30}}',
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_without_values():
"""
Test converting a tool response without 'values' key to a message
The Anthropic API returns the JSON response in the tool call, while the OpenAI spec expects it in the message content. This test ensures the tool call is converted to a message correctly.
Relevant issue: https://github.com/BerriAI/litellm/issues/6741
"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments='{"name": "John", "age": 30}'
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_invalid_json():
"""Test converting a tool response with invalid JSON"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments="invalid json"
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == "invalid json"
def test_convert_tool_response_to_message_no_arguments():
"""Test converting a tool response with no arguments"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is None