_convert_tool_response_to_message

Ishaan Jaff 2024-11-14 10:51:04 -08:00
parent da84056e59
commit c3a2c77b55


@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
raise AnthropicError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
streaming_response=response.aiter_lines(),
sync_stream=False,
json_mode=json_mode,
)
# LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
)
# LOGGING
@@ -270,10 +276,10 @@ class AnthropicChatCompletion(BaseLLM):
"arguments"
)
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
values: Optional[dict] = args.get("values")
if values is not None:
_message = litellm.Message(content=json.dumps(values))
_message = self._convert_tool_response_to_message(
tool_calls=tool_calls,
)
if _message is not None:
completion_response["stop_reason"] = "stop"
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
@@ -318,6 +324,35 @@ class AnthropicChatCompletion(BaseLLM):
model_response._hidden_params = _hidden_params
return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, the Anthropic API returns the JSON response as a tool call; we need to convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - Anthropic returns a single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
values: Optional[dict] = args.get("values")
if values is not None:
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# often, the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# a JSON decode error can occur; return the original tool response string
return litellm.Message(content=json_mode_content_str)
return None
async def acompletion_stream_function(
self,
model: str,
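
A minimal standalone sketch of the conversion above, using hypothetical tool-call payloads; the helper name, the sample data, and the plain-dict return value (used here in place of litellm's Message so the snippet is self-contained) are illustrative only, not part of the commit:

import json
from typing import Optional

def convert_tool_response_to_message(tool_calls: list) -> Optional[dict]:
    # pull the JSON-mode arguments string from the single function call
    json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    if json_mode_content_str is None:
        return None
    try:
        args = json.loads(json_mode_content_str)
        values: Optional[dict] = args.get("values")
        # prefer the nested "values" payload when present, otherwise keep the full args
        content = json.dumps(values) if values is not None else json.dumps(args)
    except json.JSONDecodeError:
        # malformed JSON: pass the raw arguments string through unchanged
        content = json_mode_content_str
    return {"role": "assistant", "content": content}

# hypothetical Anthropic JSON-mode tool call that wraps the answer in "values"
print(convert_tool_response_to_message(
    [{"function": {"arguments": '{"values": {"city": "Paris"}}'}}]
))  # {'role': 'assistant', 'content': '{"city": "Paris"}'}

# without the "values" key (see litellm issue #6741) the full arguments are returned
print(convert_tool_response_to_message(
    [{"function": {"arguments": '{"city": "Paris"}'}}]
))  # {'role': 'assistant', 'content': '{"city": "Paris"}'}

The diff below applies the same conversion on the streaming path: when json_mode is set, the ModelResponseIterator rewrites a tool_use chunk into plain text before emitting the GenericStreamingChunk.
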
@@ -334,6 +369,7 @@ class AnthropicChatCompletion(BaseLLM):
stream,
_is_function_call,
data: dict,
json_mode: bool,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -350,6 +386,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@@ -500,6 +537,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
json_mode=json_mode,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
@@ -547,6 +585,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@@ -605,11 +644,12 @@ class AnthropicChatCompletion(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
def __init__(self, streaming_response, sync_stream: bool, json_mode: bool):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1
self.json_mode = json_mode
def check_empty_tool_call_args(self) -> bool:
"""
@@ -771,6 +811,14 @@ class ModelResponseIterator:
status_code=500,  # the Anthropic API does not appear to return a status code in the chunk error, so default to 500
)
if self.json_mode is True and tool_use is not None:
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=[tool_use]
)
if message is not None:
text = message.content or ""
tool_use = None
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,