diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 2d119a28f..d56834c6e 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
 from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
         raise AnthropicError(status_code=500, message=str(e))
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
     )
 
     # LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
         )
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
     )
 
     # LOGGING
@@ -270,10 +276,10 @@ class AnthropicChatCompletion(BaseLLM):
                     "arguments"
                 )
                 if json_mode_content_str is not None:
-                    args = json.loads(json_mode_content_str)
-                    values: Optional[dict] = args.get("values")
-                    if values is not None:
-                        _message = litellm.Message(content=json.dumps(values))
+                    _message = self._convert_tool_response_to_message(
+                        tool_calls=tool_calls,
+                    )
+                    if _message is not None:
                         completion_response["stop_reason"] = "stop"
                         model_response.choices[0].message = _message  # type: ignore
                         model_response._hidden_params["original_response"] = completion_response[
@@ -318,6 +324,35 @@ class AnthropicChatCompletion(BaseLLM):
         model_response._hidden_params = _hidden_params
         return model_response
 
+    @staticmethod
+    def _convert_tool_response_to_message(
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        """
+        In JSON mode, the Anthropic API returns the JSON payload as a tool call; convert it to a message to follow the OpenAI format.
+
+        """
+        ## HANDLE JSON MODE - anthropic returns a single function call
+        json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+            "arguments"
+        )
+        try:
+            if json_mode_content_str is not None:
+                args = json.loads(json_mode_content_str)
+                values: Optional[dict] = args.get("values")
+                if values is not None:
+                    _message = litellm.Message(content=json.dumps(values))
+                    return _message
+                else:
+                    # the `values` key is often not present in the tool response
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+                    _message = litellm.Message(content=json.dumps(args))
+                    return _message
+        except json.JSONDecodeError:
+            # a JSON decode error can occur; return the original tool response string
+            return litellm.Message(content=json_mode_content_str)
+        return None
+
     async def acompletion_stream_function(
         self,
         model: str,
@@ -334,6 +369,7 @@ class AnthropicChatCompletion(BaseLLM):
         stream,
         _is_function_call,
         data: dict,
+        json_mode: bool,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -350,6 +386,7 @@ class AnthropicChatCompletion(BaseLLM):
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -500,6 +537,7 @@ class AnthropicChatCompletion(BaseLLM):
                     optional_params=optional_params,
                     stream=stream,
                     _is_function_call=_is_function_call,
+                    json_mode=json_mode,
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
@@ -547,6 +585,7 @@ class AnthropicChatCompletion(BaseLLM):
                         messages=messages,
                         logging_obj=logging_obj,
                         timeout=timeout,
+                        json_mode=json_mode,
                     )
                     return CustomStreamWrapper(
                         completion_stream=completion_stream,
@@ -605,11 +644,12 @@ class AnthropicChatCompletion(BaseLLM):
 
 
 class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
+    def __init__(self, streaming_response, sync_stream: bool, json_mode: bool):
         self.streaming_response = streaming_response
         self.response_iterator = self.streaming_response
         self.content_blocks: List[ContentBlockDelta] = []
         self.tool_index = -1
+        self.json_mode = json_mode
 
     def check_empty_tool_call_args(self) -> bool:
         """
@@ -771,6 +811,14 @@ class ModelResponseIterator:
                 status_code=500,  # it looks like the Anthropic API does not return a status code in the chunk error - default to 500
             )
 
+        if self.json_mode is True and tool_use is not None:
+            message = AnthropicChatCompletion._convert_tool_response_to_message(
+                tool_calls=[tool_use]
+            )
+            if message is not None:
+                text = message.content or ""
+                tool_use = None
+
         returned_chunk = GenericStreamingChunk(
             text=text,
             tool_use=tool_use,
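
For reviewers, here is a minimal standalone sketch of the fallback behavior the new `_convert_tool_response_to_message` helper implements, assuming the JSON-mode tool call carries a JSON object in `arguments`. The function name `content_from_tool_args` and the inline asserts are illustrative only, not part of the patch; they cover the three paths: a `values` wrapper, a bare object (the issue #6741 case), and malformed JSON.

```python
import json
from typing import Optional


def content_from_tool_args(arguments: Optional[str]) -> Optional[str]:
    """Sketch of the JSON-mode fallback: prefer the `values` key,
    fall back to the whole args object, pass malformed JSON through."""
    if arguments is None:
        return None
    try:
        args = json.loads(arguments)
    except json.JSONDecodeError:
        # malformed JSON: return the raw tool-call string unchanged
        return arguments
    values = args.get("values")
    if values is not None:
        return json.dumps(values)  # schema output wrapped in a `values` key
    return json.dumps(args)  # no `values` key present (see issue #6741)


# expected behavior on the three paths
assert content_from_tool_args('{"values": {"a": 1}}') == '{"a": 1}'
assert content_from_tool_args('{"a": 1}') == '{"a": 1}'
assert content_from_tool_args("not json") == "not json"
```

The same conversion is reused in the streaming path: when `json_mode` is set and a chunk parses to a `tool_use`, the chunk's tool call is collapsed into plain `text` and `tool_use` is cleared, so streaming JSON-mode responses also arrive as OpenAI-style message content.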