From 84652dd946ab0e2bf285dab3132b0b643ea6cbc2 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Jun 2024 14:20:25 -0700
Subject: [PATCH] fix(bedrock_httpx.py): fix tool calling for anthropic bedrock
 calls w/ streaming

Fixes https://github.com/BerriAI/litellm/issues/4091
---
 litellm/llms/bedrock_httpx.py   | 84 ++++++++++++++++++++++-----------
 litellm/tests/test_streaming.py | 18 +++++--
 litellm/types/llms/bedrock.py   | 13 ++++-
 litellm/types/llms/openai.py    | 15 +++++-
 litellm/utils.py                | 27 ++++++++---
 5 files changed, 117 insertions(+), 40 deletions(-)

diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index afc265761..336cbd3bb 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaChunk,
 )


@@ -1859,29 +1860,57 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()

     def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
-        text = ""
-        tool_str = ""
-        is_finished = False
-        finish_reason = ""
-        usage: Optional[ConverseTokenUsageBlock] = None
-        if "delta" in chunk_data:
-            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
-            if "text" in delta_obj:
-                text = delta_obj["text"]
-            elif "toolUse" in delta_obj:
-                tool_str = delta_obj["toolUse"]["input"]
-        elif "stopReason" in chunk_data:
-            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
-        elif "usage" in chunk_data:
-            usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
-        response = GenericStreamingChunk(
-            text=text,
-            tool_str=tool_str,
-            is_finished=is_finished,
-            finish_reason=finish_reason,
-            usage=usage,
-        )
-        return response
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ConverseTokenUsageBlock] = None
+
+            index = int(chunk_data.get("contentBlockIndex", 0))
+            if "start" in chunk_data:
+                start_obj = ContentBlockStartEvent(**chunk_data["start"])
+                if (
+                    start_obj is not None
+                    and "toolUse" in start_obj
+                    and start_obj["toolUse"] is not None
+                ):
+                    tool_use = {
+                        "id": start_obj["toolUse"]["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": start_obj["toolUse"]["name"],
+                            "arguments": "",
+                        },
+                    }
+            elif "delta" in chunk_data:
+                delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+                if "text" in delta_obj:
+                    text = delta_obj["text"]
+                elif "toolUse" in delta_obj:
+                    tool_use = {
+                        "id": None,
+                        "type": "function",
+                        "function": {
+                            "name": None,
+                            "arguments": delta_obj["toolUse"]["input"],
+                        },
+                    }
+            elif "stopReason" in chunk_data:
+                finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+            elif "usage" in chunk_data:
+                usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+            response = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+            )
+            return response
+        except Exception as e:
+            raise Exception("Received streaming error - {}".format(str(e)))

     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
@@ -1890,12 +1919,12 @@ class AWSEventStreamDecoder:
         if "outputText" in chunk_data:
             text = chunk_data["outputText"]
         # ai21 mapping
-        if "ai21" in self.model:  # fake ai21 streaming
+        elif "ai21" in self.model:  # fake ai21 streaming
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "delta" in chunk_data:
+        elif "contentBlockIndex" in chunk_data:
             return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
@@ -1905,7 +1934,7 @@ class AWSEventStreamDecoder:
             ):
                 text = chunk_data["outputs"][0]["text"]
             stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
+            if stop_reason is not None:
                 is_finished = True
                 finish_reason = stop_reason
         ######## bedrock.cohere mappings ###############
@@ -1926,8 +1955,9 @@ class AWSEventStreamDecoder:
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            tool_str="",
             usage=None,
+            index=0,
+            tool_use=None,
         )

     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index a5e098b02..ee9f1cec6 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished


-def test_openai_streaming_and_function_calling():
+@pytest.mark.parametrize(
+    "model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
+)
+def test_streaming_and_function_calling(model):
     tools = [
         {
             "type": "function",
@@ -2556,23 +2559,30 @@ def test_openai_streaming_and_function_calling():
             },
         }
     ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
     try:
-        response = completion(
-            model="gpt-3.5-turbo",
+        litellm.set_verbose = True
+        response: litellm.CustomStreamWrapper = completion(
+            model=model,
             tools=tools,
             messages=messages,
             stream=True,
-        )
+            tool_choice="required",
+        )  # type: ignore
         # Add any assertions here to check the response
         for idx, chunk in enumerate(response):
+            # continue
+            # print("\n{}\n".format(chunk))
             if idx == 0:
+                print(chunk)
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
                 )
                 assert isinstance(
                     chunk.choices[0].delta.tool_calls[0].function.arguments, str
                 )
+        # assert False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index b06075092..95ebc9742 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -1,5 +1,6 @@
 from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
+from .openai import ChatCompletionToolCallChunk
 from typing_extensions import (
     Self,
     Protocol,
@@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
     input: str


+class ToolUseBlockStartEvent(TypedDict):
+    name: str
+    toolUseId: str
+
+
+class ContentBlockStartEvent(TypedDict, total=False):
+    toolUse: Optional[ToolUseBlockStartEvent]
+
+
 class ContentBlockDeltaEvent(TypedDict, total=False):
     """
     Either 'text' or 'toolUse' will be specified for Converse API streaming response.
@@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):

 class GenericStreamingChunk(TypedDict):
     text: Required[str]
-    tool_str: Required[str]
+    tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
     usage: Optional[ConverseTokenUsageBlock]
+    index: int


 class Document(TypedDict):
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 7861e394c..66aec4906 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):


 class ChatCompletionToolCallFunctionChunk(TypedDict):
-    name: str
+    name: Optional[str]
     arguments: str


 class ChatCompletionToolCallChunk(TypedDict):
+    id: Optional[str]
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionDeltaToolCallChunk(TypedDict):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
+    index: int
+
+
+class ChatCompletionDeltaChunk(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionDeltaToolCallChunk]
+    role: str


 class ChatCompletionResponseMessage(TypedDict, total=False):
diff --git a/litellm/utils.py b/litellm/utils.py
index 410f9ad88..fc6ac9fec 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
 import importlib.metadata
 from ._logging import verbose_logger
 from .types.router import LiteLLM_Params
+from .types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaToolCallChunk,
+)
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
@@ -3218,7 +3223,7 @@ def client(original_function):
                         stream=kwargs.get("stream", False),
                     )

-                    if kwargs.get("stream", False) == True:
+                    if kwargs.get("stream", False) is True:
                         cached_result = CustomStreamWrapper(
                             completion_stream=cached_result,
                             model=model,
@@ -11468,6 +11473,10 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"]["outputTokens"],
                         total_tokens=response_obj["usage"]["totalTokens"],
                     )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
+
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
@@ -11484,7 +11493,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "palm":
                 # fake streaming
                 response_obj = {}
@@ -11497,7 +11505,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "ollama":
                 response_obj = self.handle_ollama_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11717,7 +11724,7 @@ class CustomStreamWrapper:
             else:
                 if (
                     self.stream_options is not None
-                    and self.stream_options["include_usage"] == True
+                    and self.stream_options["include_usage"] is True
                 ):
                     return model_response
                 return
@@ -11742,8 +11749,14 @@ class CustomStreamWrapper:
                 return model_response
             elif (
                 "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) > 0
+                and (
+                    isinstance(completion_obj["content"], str)
+                    and len(completion_obj["content"]) > 0
+                )
+                or (
+                    "tool_calls" in completion_obj
+                    and len(completion_obj["tool_calls"]) > 0
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
@@ -11799,7 +11812,7 @@ class CustomStreamWrapper:
                     else:
                         ## else
                         completion_obj["content"] = model_response_str
-                        if self.sent_first_chunk == False:
+                        if self.sent_first_chunk is False:
                             completion_obj["role"] = "assistant"
                             self.sent_first_chunk = True
                         model_response.choices[0].delta = Delta(**completion_obj)
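
For reference, a minimal usage sketch of how a caller can reassemble the tool call that the patched streaming path emits: the first tool-call chunk carries the id and function name (mapped from the Converse contentBlockStart event), and later chunks append fragments of the JSON arguments (from contentBlockDelta events). This is an illustration, not part of the patch: AWS credentials, Bedrock model access in your region, and the get_current_weather tool schema are assumptions; the model name and prompt mirror the parametrized test above.

import json

from litellm import completion


def collect_streamed_tool_call():
    # Illustrative tool schema (assumption, not taken from the patch).
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

    # Assumes the Bedrock model used in the parametrized test is available.
    response = completion(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=messages,
        tools=tools,
        stream=True,
    )

    name, arguments = None, ""
    for chunk in response:
        tool_calls = chunk.choices[0].delta.tool_calls
        if not tool_calls:
            continue
        call = tool_calls[0]
        # First tool chunk carries the function name (contentBlockStart);
        # subsequent chunks stream the JSON arguments (contentBlockDelta).
        if call.function.name:
            name = call.function.name
        if call.function.arguments:
            arguments += call.function.arguments

    return name, json.loads(arguments) if arguments else {}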