forked from phoenix/litellm-mirror
_convert_tool_response_to_message
This commit is contained in:
parent da84056e59
commit c3a2c77b55

1 changed file with 56 additions and 8 deletions
@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
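
For orientation, a hedged sketch of how the updated async helper would now be invoked. Only the parameters from `messages` downward are visible in this hunk; the other keyword names (client, api_base, headers, data, model) are assumptions based on the surrounding signature, and the real call-site updates appear in later hunks of this commit.

completion_stream, response_headers = await make_call(
    client=None,              # None falls back to litellm.module_level_aclient
    api_base=api_base,        # assumed name, not shown in this hunk
    headers=headers,          # assumed name, not shown in this hunk
    data=json.dumps(data),    # assumed name, not shown in this hunk
    model=model,
    messages=messages,
    logging_obj=logging_obj,
    timeout=timeout,
    json_mode=json_mode,      # new required keyword introduced by this commit
)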
@@ -119,7 +122,9 @@ async def make_call(
         raise AnthropicError(status_code=500, message=str(e))

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
     )

     # LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
 ) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
         )

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
     )

     # LOGGING
@@ -270,10 +276,10 @@ class AnthropicChatCompletion(BaseLLM):
                     "arguments"
                 )
                 if json_mode_content_str is not None:
-                    args = json.loads(json_mode_content_str)
-                    values: Optional[dict] = args.get("values")
-                    if values is not None:
-                        _message = litellm.Message(content=json.dumps(values))
+                    _message = self._convert_tool_response_to_message(
+                        tool_calls=tool_calls,
+                    )
+                    if _message is not None:
                         completion_response["stop_reason"] = "stop"
                     model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
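
Seen from the caller's side, the delegation above means that in JSON mode the converted message replaces the tool call on the returned response. A minimal sketch, assuming litellm's public completion API and that response_format={"type": "json_object"} is what drives json_mode for Anthropic models; the exact output is model-dependent.

import litellm

resp = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": 'Answer as JSON: {"answer": <2+2>}'}],
    response_format={"type": "json_object"},  # assumed to set json_mode downstream
)
# The tool-call arguments are surfaced as plain message content, e.g. '{"answer": 4}'
print(resp.choices[0].message.content)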
@@ -318,6 +324,35 @@ class AnthropicChatCompletion(BaseLLM):
         model_response._hidden_params = _hidden_params
         return model_response

+    @staticmethod
+    def _convert_tool_response_to_message(
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        """
+        In JSON mode, the Anthropic API returns the JSON schema as a tool call; we need to convert it to a message to follow the OpenAI format.
+
+        """
+        ## HANDLE JSON MODE - anthropic returns single function call
+        json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+            "arguments"
+        )
+        try:
+            if json_mode_content_str is not None:
+                args = json.loads(json_mode_content_str)
+                values: Optional[dict] = args.get("values")
+                if values is not None:
+                    _message = litellm.Message(content=json.dumps(values))
+                    return _message
+                else:
+                    # a lot of the time the `values` key is not present in the tool response
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+                    _message = litellm.Message(content=json.dumps(args))
+                    return _message
+        except json.JSONDecodeError:
+            # a json decode error can occur - return the original tool response str
+            return litellm.Message(content=json_mode_content_str)
+        return None
+
     async def acompletion_stream_function(
         self,
         model: str,
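
A hedged usage sketch of the new static helper (illustrative, not part of the commit). The dict shape follows litellm's ChatCompletionToolCallChunk; the tool name and import path are placeholders that may differ by litellm version.

# Assumed import path for the handler at the time of this commit; adjust if needed.
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion

# Case 1: the model wraps its output in a "values" key - only the inner dict is kept.
tool_call = {
    "id": "toolu_01",
    "type": "function",
    "index": 0,
    "function": {
        "name": "json_tool_call",  # placeholder name
        "arguments": '{"values": {"city": "Paris", "temp_c": 21}}',
    },
}
msg = AnthropicChatCompletion._convert_tool_response_to_message(tool_calls=[tool_call])
# msg.content == '{"city": "Paris", "temp_c": 21}'

# Case 2: no "values" wrapper (see issue 6741 above) - the whole arguments object is kept.
tool_call["function"]["arguments"] = '{"city": "Paris", "temp_c": 21}'
msg = AnthropicChatCompletion._convert_tool_response_to_message(tool_calls=[tool_call])
# msg.content == '{"city": "Paris", "temp_c": 21}'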
@@ -334,6 +369,7 @@ class AnthropicChatCompletion(BaseLLM):
         stream,
         _is_function_call,
         data: dict,
+        json_mode: bool,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -350,6 +386,7 @@ class AnthropicChatCompletion(BaseLLM):
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -500,6 +537,7 @@ class AnthropicChatCompletion(BaseLLM):
                 optional_params=optional_params,
                 stream=stream,
                 _is_function_call=_is_function_call,
+                json_mode=json_mode,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
@@ -547,6 +585,7 @@ class AnthropicChatCompletion(BaseLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 timeout=timeout,
+                json_mode=json_mode,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -605,11 +644,12 @@ class AnthropicChatCompletion(BaseLLM):


 class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
+    def __init__(self, streaming_response, sync_stream: bool, json_mode: bool):
         self.streaming_response = streaming_response
         self.response_iterator = self.streaming_response
         self.content_blocks: List[ContentBlockDelta] = []
         self.tool_index = -1
+        self.json_mode = json_mode

     def check_empty_tool_call_args(self) -> bool:
         """
@@ -771,6 +811,14 @@ class ModelResponseIterator:
                     status_code=500,  # it looks like Anthropic API does not return a status code in the chunk error - default to 500
                 )

+            if self.json_mode is True and tool_use is not None:
+                message = AnthropicChatCompletion._convert_tool_response_to_message(
+                    tool_calls=[tool_use]
+                )
+                if message is not None:
+                    text = message.content or ""
+                    tool_use = None
+
             returned_chunk = GenericStreamingChunk(
                 text=text,
                 tool_use=tool_use,
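
A short sketch of what the new streaming branch does with a parsed tool-use delta when json_mode is enabled. The dict below mirrors the tool_use value the chunk parser would have built and is illustrative only; AnthropicChatCompletion is assumed to be imported as in the sketch after the helper above.

tool_use = {
    "id": "toolu_01",
    "type": "function",
    "index": 0,
    "function": {"name": "json_tool_call", "arguments": '{"values": {"ok": true}}'},
}

# Same conversion the iterator performs above: the JSON arguments become the chunk's
# plain text and the tool call itself is dropped.
message = AnthropicChatCompletion._convert_tool_response_to_message(tool_calls=[tool_use])
if message is not None:
    text = message.content or ""  # '{"ok": true}'
    tool_use = None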