diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index afc265761..b011d9512 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaChunk,
 )
 
 
@@ -1859,29 +1860,59 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()
 
     def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
-        text = ""
-        tool_str = ""
-        is_finished = False
-        finish_reason = ""
-        usage: Optional[ConverseTokenUsageBlock] = None
-        if "delta" in chunk_data:
-            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
-            if "text" in delta_obj:
-                text = delta_obj["text"]
-            elif "toolUse" in delta_obj:
-                tool_str = delta_obj["toolUse"]["input"]
-        elif "stopReason" in chunk_data:
-            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
-        elif "usage" in chunk_data:
-            usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
-        response = GenericStreamingChunk(
-            text=text,
-            tool_str=tool_str,
-            is_finished=is_finished,
-            finish_reason=finish_reason,
-            usage=usage,
-        )
-        return response
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ConverseTokenUsageBlock] = None
+
+            index = int(chunk_data.get("contentBlockIndex", 0))
+            if "start" in chunk_data:
+                start_obj = ContentBlockStartEvent(**chunk_data["start"])
+                if (
+                    start_obj is not None
+                    and "toolUse" in start_obj
+                    and start_obj["toolUse"] is not None
+                ):
+                    tool_use = {
+                        "id": start_obj["toolUse"]["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": start_obj["toolUse"]["name"],
+                            "arguments": "",
+                        },
+                    }
+            elif "delta" in chunk_data:
+                delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+                if "text" in delta_obj:
+                    text = delta_obj["text"]
+                elif "toolUse" in delta_obj:
+                    tool_use = {
+                        "id": None,
+                        "type": "function",
+                        "function": {
+                            "name": None,
+                            "arguments": delta_obj["toolUse"]["input"],
+                        },
+                    }
+            elif "stopReason" in chunk_data:
+                finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+                is_finished = True
+            elif "usage" in chunk_data:
+                usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+
+            response = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+            )
+            return response
+        except Exception as e:
+            raise Exception("Received streaming error - {}".format(str(e)))
 
     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
@@ -1890,12 +1921,16 @@ class AWSEventStreamDecoder:
         if "outputText" in chunk_data:
             text = chunk_data["outputText"]
         # ai21 mapping
-        if "ai21" in self.model:  # fake ai21 streaming
+        elif "ai21" in self.model:  # fake ai21 streaming
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "delta" in chunk_data:
+        elif (
+            "contentBlockIndex" in chunk_data
+            or "stopReason" in chunk_data
+            or "metrics" in chunk_data
+        ):
             return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
@@ -1905,7 +1940,7 @@ class AWSEventStreamDecoder:
             ):
                 text = chunk_data["outputs"][0]["text"]
             stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
+            if stop_reason is not None:
                 is_finished = True
                 finish_reason = stop_reason
         ######## bedrock.cohere mappings ###############
@@ -1926,8 +1961,9 @@ class AWSEventStreamDecoder:
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            tool_str="",
             usage=None,
+            index=0,
+            tool_use=None,
         )
 
     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index a5e098b02..ce69c0bf9 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished
 
 
-def test_openai_streaming_and_function_calling():
+@pytest.mark.parametrize(
+    "model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
+)
+def test_streaming_and_function_calling(model):
     tools = [
         {
             "type": "function",
@@ -2556,16 +2559,21 @@ def test_openai_streaming_and_function_calling():
             },
         }
     ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
     try:
-        response = completion(
-            model="gpt-3.5-turbo",
+        litellm.set_verbose = True
+        response: litellm.CustomStreamWrapper = completion(
+            model=model,
             tools=tools,
             messages=messages,
             stream=True,
-        )
+            tool_choice="required",
+        )  # type: ignore
         # Add any assertions here to check the response
         for idx, chunk in enumerate(response):
+            # continue
+            print("\n{}\n".format(chunk))
             if idx == 0:
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
                 )
@@ -2573,6 +2581,7 @@ def test_openai_streaming_and_function_calling():
                 assert isinstance(
                     chunk.choices[0].delta.tool_calls[0].function.arguments, str
                 )
+        # assert False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 6a093af23..65d5bcac2 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -3990,6 +3990,7 @@ def test_async_text_completion():
     asyncio.run(test_get_response())
 
 
+@pytest.mark.skip(reason="Tgai endpoints are unstable")
 def test_async_text_completion_together_ai():
     litellm.set_verbose = True
     print("test_async_text_completion")
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index b06075092..95ebc9742 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -1,5 +1,6 @@
 from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
+from .openai import ChatCompletionToolCallChunk
 from typing_extensions import (
     Self,
     Protocol,
@@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
     input: str
 
 
+class ToolUseBlockStartEvent(TypedDict):
+    name: str
+    toolUseId: str
+
+
+class ContentBlockStartEvent(TypedDict, total=False):
+    toolUse: Optional[ToolUseBlockStartEvent]
+
+
 class ContentBlockDeltaEvent(TypedDict, total=False):
     """
     Either 'text' or 'toolUse' will be specified for Converse API streaming response
@@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):
 
 class GenericStreamingChunk(TypedDict):
     text: Required[str]
-    tool_str: Required[str]
+    tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
     usage: Optional[ConverseTokenUsageBlock]
+    index: int
 
 
 class Document(TypedDict):
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 7861e394c..66aec4906 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):
 
 
 class ChatCompletionToolCallFunctionChunk(TypedDict):
-    name: str
+    name: Optional[str]
     arguments: str
 
 
 class ChatCompletionToolCallChunk(TypedDict):
+    id: Optional[str]
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionDeltaToolCallChunk(TypedDict):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
+    index: int
+
+
+class ChatCompletionDeltaChunk(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionDeltaToolCallChunk]
+    role: str
 
 
 class ChatCompletionResponseMessage(TypedDict, total=False):
diff --git a/litellm/utils.py b/litellm/utils.py
index ed405419b..dbaa3d093 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
 import importlib.metadata
 from ._logging import verbose_logger
 from .types.router import LiteLLM_Params
+from .types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaToolCallChunk,
+)
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
@@ -3250,7 +3255,7 @@ def client(original_function):
                         stream=kwargs.get("stream", False),
                     )
 
-                    if kwargs.get("stream", False) == True:
+                    if kwargs.get("stream", False) is True:
                         cached_result = CustomStreamWrapper(
                             completion_stream=cached_result,
                             model=model,
@@ -11301,7 +11306,6 @@ class CustomStreamWrapper:
                     raise StopIteration
                 response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
-
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
 
@@ -11316,6 +11320,10 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"]["outputTokens"],
                         total_tokens=response_obj["usage"]["totalTokens"],
                     )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
+
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
@@ -11332,7 +11340,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "palm":  # fake streaming
                 response_obj = {}
 
@@ -11345,7 +11352,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "ollama":
                 response_obj = self.handle_ollama_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11432,7 +11438,7 @@ class CustomStreamWrapper:
                         # for azure, we need to pass the model from the orignal chunk
                         self.model = chunk.model
                 response_obj = self.handle_openai_chat_completion_chunk(chunk)
-                if response_obj == None:
+                if response_obj is None:
                     return
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -11565,7 +11571,7 @@ class CustomStreamWrapper:
                 else:
                     if (
                         self.stream_options is not None
-                        and self.stream_options["include_usage"] == True
+                        and self.stream_options["include_usage"] is True
                     ):
                         return model_response
                     return
@@ -11590,8 +11596,14 @@ class CustomStreamWrapper:
                 return model_response
             elif (
                 "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) > 0
+                and (
+                    isinstance(completion_obj["content"], str)
+                    and len(completion_obj["content"]) > 0
+                )
+                or (
+                    "tool_calls" in completion_obj
+                    and len(completion_obj["tool_calls"]) > 0
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
@@ -11647,7 +11659,7 @@ class CustomStreamWrapper:
                     else:
                         ## else
                         completion_obj["content"] = model_response_str
-                        if self.sent_first_chunk == False:
+                        if self.sent_first_chunk is False:
                             completion_obj["role"] = "assistant"
                             self.sent_first_chunk = True
                         model_response.choices[0].delta = Delta(**completion_obj)
@@ -11656,7 +11668,7 @@ class CustomStreamWrapper:
                 else:
                     return
             elif self.received_finish_reason is not None:
-                if self.sent_last_chunk == True:
+                if self.sent_last_chunk is True:
                     raise StopIteration
                 # flush any remaining holding chunk
                 if len(self.holding_chunk) > 0: