fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming

Fixes https://github.com/BerriAI/litellm/issues/4091
Krrish Dholakia 2024-06-10 14:20:25 -07:00
parent 6306914e56
commit 84652dd946
5 changed files with 117 additions and 40 deletions
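
With this change, tool calls from Bedrock Anthropic (Converse API) models stream back the same way they do from OpenAI, as delta.tool_calls chunks. A minimal sketch of the call pattern covered by the updated test below; the tool schema is illustrative (the full schema in the test is elided in this diff), and it assumes AWS credentials for Bedrock are configured in the environment:

import litellm

# Illustrative tool schema; mirrors the weather-style tool used in the test.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

# Assumes AWS credentials for Bedrock are set in the environment.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=messages,
    tools=tools,
    tool_choice="required",
    stream=True,
)

for chunk in response:
    tool_calls = chunk.choices[0].delta.tool_calls
    if tool_calls:
        # With this fix, tool-call arguments arrive incrementally as string fragments.
        print(tool_calls[0].function.arguments, end="")

Before this fix, the Converse tool-use deltas were only surfaced as a raw tool_str on the provider chunk and never reached delta.tool_calls on the streamed response.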

litellm/llms/bedrock_httpx.py

@@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaChunk,
 )
@@ -1859,29 +1860,57 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()

     def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        try:
             text = ""
-        tool_str = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
             is_finished = False
             finish_reason = ""
             usage: Optional[ConverseTokenUsageBlock] = None
-        if "delta" in chunk_data:
+            index = int(chunk_data.get("contentBlockIndex", 0))
+            if "start" in chunk_data:
+                start_obj = ContentBlockStartEvent(**chunk_data["start"])
+                if (
+                    start_obj is not None
+                    and "toolUse" in start_obj
+                    and start_obj["toolUse"] is not None
+                ):
+                    tool_use = {
+                        "id": start_obj["toolUse"]["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": start_obj["toolUse"]["name"],
+                            "arguments": "",
+                        },
+                    }
+            elif "delta" in chunk_data:
                 delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
                 if "text" in delta_obj:
                     text = delta_obj["text"]
                 elif "toolUse" in delta_obj:
-                tool_str = delta_obj["toolUse"]["input"]
+                    tool_use = {
+                        "id": None,
+                        "type": "function",
+                        "function": {
+                            "name": None,
+                            "arguments": delta_obj["toolUse"]["input"],
+                        },
+                    }
             elif "stopReason" in chunk_data:
                 finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
             elif "usage" in chunk_data:
                 usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore

             response = GenericStreamingChunk(
                 text=text,
-            tool_str=tool_str,
+                tool_use=tool_use,
                 is_finished=is_finished,
                 finish_reason=finish_reason,
                 usage=usage,
+                index=index,
             )
             return response
+        except Exception as e:
+            raise Exception("Received streaming error - {}".format(str(e)))

     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
@@ -1890,12 +1919,12 @@ class AWSEventStreamDecoder:
         if "outputText" in chunk_data:
             text = chunk_data["outputText"]
         # ai21 mapping
-        if "ai21" in self.model:  # fake ai21 streaming
+        elif "ai21" in self.model:  # fake ai21 streaming
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "delta" in chunk_data:
+        elif "contentBlockIndex" in chunk_data:
             return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
@@ -1905,7 +1934,7 @@
         ):
             text = chunk_data["outputs"][0]["text"]
             stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
+            if stop_reason is not None:
                 is_finished = True
                 finish_reason = stop_reason
         ######## bedrock.cohere mappings ###############
@@ -1926,8 +1955,9 @@
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            tool_str="",
             usage=None,
+            index=0,
+            tool_use=None,
         )

     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
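
For reference, the Converse stream events that converse_chunk_parser branches on look roughly like this. The shapes follow the field names in the code above (start.toolUse.toolUseId/name, delta.toolUse.input); the concrete values are made up rather than captured AWS output:

# Sketch of Converse stream payloads and the tool_use chunks the parser above
# would build for them; values are illustrative, shapes follow the code.

# contentBlockStart event: carries the tool call id and name, arguments start empty.
start_event = {
    "contentBlockIndex": 1,
    "start": {"toolUse": {"toolUseId": "tooluse_abc123", "name": "get_current_weather"}},
}
# -> tool_use == {"id": "tooluse_abc123", "type": "function",
#                 "function": {"name": "get_current_weather", "arguments": ""}}

# contentBlockDelta event: streams argument fragments; id and name stay None.
delta_event = {
    "contentBlockIndex": 1,
    "delta": {"toolUse": {"input": '{"location": "Boston'}},
}
# -> tool_use == {"id": None, "type": "function",
#                 "function": {"name": None, "arguments": '{"location": "Boston'}}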

litellm/tests/test_streaming.py

@@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished

-def test_openai_streaming_and_function_calling():
+@pytest.mark.parametrize(
+    "model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
+)
+def test_streaming_and_function_calling(model):
     tools = [
         {
             "type": "function",
@@ -2556,23 +2559,30 @@
             },
         }
     ]
     messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
     try:
-        response = completion(
-            model="gpt-3.5-turbo",
+        litellm.set_verbose = True
+        response: litellm.CustomStreamWrapper = completion(
+            model=model,
             tools=tools,
             messages=messages,
             stream=True,
-        )
+            tool_choice="required",
+        )  # type: ignore
         # Add any assertions here to check the response
         for idx, chunk in enumerate(response):
+            # continue
+            # print("\n{}\n".format(chunk))
             if idx == 0:
+                print(chunk)
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
                 )
                 assert isinstance(
                     chunk.choices[0].delta.tool_calls[0].function.arguments, str
                 )
+        # assert False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e

litellm/types/llms/bedrock.py

@@ -1,5 +1,6 @@
 from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
+from .openai import ChatCompletionToolCallChunk
 from typing_extensions import (
     Self,
     Protocol,
@@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
     input: str


+class ToolUseBlockStartEvent(TypedDict):
+    name: str
+    toolUseId: str
+
+
+class ContentBlockStartEvent(TypedDict, total=False):
+    toolUse: Optional[ToolUseBlockStartEvent]
+
+
 class ContentBlockDeltaEvent(TypedDict, total=False):
     """
     Either 'text' or 'toolUse' will be specified for Converse API streaming response.
@@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):
 class GenericStreamingChunk(TypedDict):
     text: Required[str]
-    tool_str: Required[str]
+    tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
     usage: Optional[ConverseTokenUsageBlock]
+    index: int


 class Document(TypedDict):

litellm/types/llms/openai.py

@@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):
 class ChatCompletionToolCallFunctionChunk(TypedDict):
-    name: str
+    name: Optional[str]
     arguments: str


 class ChatCompletionToolCallChunk(TypedDict):
+    id: Optional[str]
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionDeltaToolCallChunk(TypedDict):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
+    index: int
+
+
+class ChatCompletionDeltaChunk(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionDeltaToolCallChunk]
+    role: str


 class ChatCompletionResponseMessage(TypedDict, total=False):

litellm/utils.py

@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
 import importlib.metadata
 from ._logging import verbose_logger
 from .types.router import LiteLLM_Params
+from .types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaToolCallChunk,
+)
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
@@ -3218,7 +3223,7 @@ def client(original_function):
                         stream=kwargs.get("stream", False),
                     )
-                    if kwargs.get("stream", False) == True:
+                    if kwargs.get("stream", False) is True:
                         cached_result = CustomStreamWrapper(
                             completion_stream=cached_result,
                             model=model,
@@ -11468,6 +11473,10 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"]["outputTokens"],
                         total_tokens=response_obj["usage"]["totalTokens"],
                     )
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
@@ -11484,7 +11493,6 @@
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "palm":
                 # fake streaming
                 response_obj = {}
@@ -11497,7 +11505,6 @@
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "ollama":
                 response_obj = self.handle_ollama_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11717,7 +11724,7 @@
             else:
                 if (
                     self.stream_options is not None
-                    and self.stream_options["include_usage"] == True
+                    and self.stream_options["include_usage"] is True
                 ):
                     return model_response
                 return
@@ -11742,8 +11749,14 @@
                 return model_response
             elif (
                 "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) > 0
+                and (
+                    isinstance(completion_obj["content"], str)
+                    and len(completion_obj["content"]) > 0
+                )
+                or (
+                    "tool_calls" in completion_obj
+                    and len(completion_obj["tool_calls"]) > 0
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
@@ -11799,7 +11812,7 @@
                 else:
                     ## else
                     completion_obj["content"] = model_response_str
-                    if self.sent_first_chunk == False:
+                    if self.sent_first_chunk is False:
                         completion_obj["role"] = "assistant"
                         self.sent_first_chunk = True
                     model_response.choices[0].delta = Delta(**completion_obj)