_convert_tool_response_to_message

Ishaan Jaff 2024-11-14 10:51:04 -08:00
parent da84056e59
commit c3a2c77b55


@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
 from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
         raise AnthropicError(status_code=500, message=str(e))
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
     )
 
     # LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
         )
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
    )
 
    # LOGGING
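
Aside: the json_mode flag threaded through make_call and make_sync_call above originates from the OpenAI-style response_format parameter. A hypothetical call that would exercise this path (model name and prompt are illustrative, not part of this commit):

import litellm

# Hypothetical usage sketch: response_format={"type": "json_object"} puts the
# Anthropic request into JSON mode; this commit threads json_mode=True from
# there through make_call / make_sync_call into ModelResponseIterator.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Return the capital of France as JSON."}],
    response_format={"type": "json_object"},
    stream=True,
)
for chunk in response:
    # with this commit, JSON-mode stream chunks carry plain text, not tool calls
    print(chunk.choices[0].delta.content or "", end="")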
@@ -270,10 +276,10 @@ class AnthropicChatCompletion(BaseLLM):
                     "arguments"
                 )
                 if json_mode_content_str is not None:
-                    args = json.loads(json_mode_content_str)
-                    values: Optional[dict] = args.get("values")
-                    if values is not None:
-                        _message = litellm.Message(content=json.dumps(values))
+                    _message = self._convert_tool_response_to_message(
+                        tool_calls=tool_calls,
+                    )
+                    if _message is not None:
                         completion_response["stop_reason"] = "stop"
                         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
@@ -318,6 +324,35 @@ class AnthropicChatCompletion(BaseLLM):
         model_response._hidden_params = _hidden_params
         return model_response
 
+    @staticmethod
+    def _convert_tool_response_to_message(
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        """
+        In JSON mode, the Anthropic API returns the JSON schema as a tool call;
+        convert it to a message so the response follows the OpenAI format.
+        """
+        ## HANDLE JSON MODE - anthropic returns a single function call
+        json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+            "arguments"
+        )
+        try:
+            if json_mode_content_str is not None:
+                args = json.loads(json_mode_content_str)
+                values: Optional[dict] = args.get("values")
+                if values is not None:
+                    _message = litellm.Message(content=json.dumps(values))
+                    return _message
+                else:
+                    # the `values` key is often missing from the tool response
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+                    _message = litellm.Message(content=json.dumps(args))
+                    return _message
+        except json.JSONDecodeError:
+            # JSON decode errors do occur; return the original tool response string
+            return litellm.Message(content=json_mode_content_str)
+        return None
+
     async def acompletion_stream_function(
         self,
         model: str,
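
The new helper covers three shapes of tool response: arguments wrapped in a "values" key, arguments without it, and arguments that are not valid JSON. A minimal standalone sketch of that logic (sample payloads are illustrative, not taken from the commit):

import json
from typing import Optional

# Illustrative payloads in the shape _convert_tool_response_to_message expects.
with_values = [{"function": {"arguments": '{"values": {"city": "Paris"}}'}}]
without_values = [{"function": {"arguments": '{"city": "Paris"}'}}]  # the issue 6741 case
invalid_json = [{"function": {"arguments": "plain text, not JSON"}}]

def convert(tool_calls: list) -> Optional[str]:
    # Same control flow as the helper above, returning plain strings
    # instead of litellm.Message objects.
    args_str = tool_calls[0]["function"].get("arguments")
    if args_str is None:
        return None
    try:
        args = json.loads(args_str)
    except json.JSONDecodeError:
        return args_str  # fall back to the raw tool response string
    values = args.get("values")
    return json.dumps(values) if values is not None else json.dumps(args)

print(convert(with_values))     # {"city": "Paris"}
print(convert(without_values))  # {"city": "Paris"}
print(convert(invalid_json))    # plain text, not JSON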
@@ -334,6 +369,7 @@ class AnthropicChatCompletion(BaseLLM):
         stream,
         _is_function_call,
         data: dict,
+        json_mode: bool,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -350,6 +386,7 @@ class AnthropicChatCompletion(BaseLLM):
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
+           json_mode=json_mode,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
@@ -500,6 +537,7 @@ class AnthropicChatCompletion(BaseLLM):
                optional_params=optional_params,
                stream=stream,
                _is_function_call=_is_function_call,
+               json_mode=json_mode,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
@@ -547,6 +585,7 @@ class AnthropicChatCompletion(BaseLLM):
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
+                   json_mode=json_mode,
                )
                return CustomStreamWrapper(
                    completion_stream=completion_stream,
@@ -605,11 +644,12 @@
 class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
+    def __init__(self, streaming_response, sync_stream: bool, json_mode: bool):
         self.streaming_response = streaming_response
         self.response_iterator = self.streaming_response
         self.content_blocks: List[ContentBlockDelta] = []
         self.tool_index = -1
+        self.json_mode = json_mode
 
     def check_empty_tool_call_args(self) -> bool:
         """
@@ -771,6 +811,14 @@ class ModelResponseIterator:
                status_code=500,  # it looks like Anthropic API does not return a status code in the chunk error - default to 500
            )
 
+        if self.json_mode is True and tool_use is not None:
+            message = AnthropicChatCompletion._convert_tool_response_to_message(
+                tool_calls=[tool_use]
+            )
+            if message is not None:
+                text = message.content or ""
+                tool_use = None
+
         returned_chunk = GenericStreamingChunk(
             text=text,
             tool_use=tool_use,
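
In the streaming path the same conversion runs per chunk: with json_mode set, a chunk carrying a tool_use block has the tool arguments folded back into text and the tool call suppressed, so the consumer sees an ordinary assistant message. A rough standalone sketch of that branch (simplified; it skips the values unwrapping the real helper performs):

from typing import Optional, Tuple

def fold_tool_use_into_text(
    text: str, tool_use: Optional[dict], json_mode: bool
) -> Tuple[str, Optional[dict]]:
    # Mirrors the new branch above: in JSON mode the tool arguments replace
    # the chunk text, and tool_use is dropped from the returned chunk.
    if json_mode and tool_use is not None:
        args_str = tool_use["function"].get("arguments")
        if args_str is not None:
            text = args_str
            tool_use = None
    return text, tool_use

chunk_text, chunk_tool = fold_tool_use_into_text(
    text="",
    tool_use={"function": {"arguments": '{"values": {"ok": true}}'}},
    json_mode=True,
)
print(chunk_text)  # {"values": {"ok": true}}
print(chunk_tool)  # None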