diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 0f52d3abc..8f91ecc26 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -11,6 +11,7 @@ from .prompt_templates.factory import (
     construct_tool_use_system_prompt,
     extract_between_tags,
     parse_xml_params,
+    contains_tag,
 )
 
 import httpx
@@ -679,6 +680,7 @@ def completion(
     timeout=None,
 ):
     exception_mapping_worked = False
+    _is_function_call = False
     try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@@ -727,8 +729,10 @@ def completion(
                         system_messages.append(message["content"])
                         system_prompt_idx.append(idx)
                 if len(system_prompt_idx) > 0:
-                    inference_params["system"] = '\n'.join(system_messages)
-                    messages = [i for j, i in enumerate(messages) if j not in system_prompt_idx]
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
                 # Format rest of message according to anthropic guidelines
                 messages = prompt_factory(
                     model=model, messages=messages, custom_llm_provider="anthropic"
@@ -742,6 +746,7 @@ def completion(
                         inference_params[k] = v
                 ## Handle Tool Calling
                 if "tools" in inference_params:
+                    _is_function_call = True
                     tool_calling_system_prompt = construct_tool_use_system_prompt(
                         tools=inference_params["tools"]
                     )
@@ -823,7 +828,7 @@ def completion(
     ## COMPLETION CALL
     accept = "application/json"
     contentType = "application/json"
-    if stream == True:
+    if stream == True and _is_function_call == False:
         if provider == "ai21":
             ## LOGGING
             request_str = f"""
@@ -918,7 +923,9 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
-                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                if outputText is not None and contains_tag(
+                    "invoke", outputText
+                ):  # OUTPUT PARSE FUNCTION CALL
                     function_name = extract_between_tags("tool_name", outputText)[0]
                     function_arguments_str = extract_between_tags("invoke", outputText)[
                         0
@@ -941,6 +948,56 @@ def completion(
                         content=None,
                     )
                     model_response.choices[0].message = _message  # type: ignore
+                    if _is_function_call == True and stream is not None and stream == True:
+                        print_verbose(
+                            f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                        )
+                        # return an iterator
+                        streaming_model_response = ModelResponse(stream=True)
+                        streaming_model_response.choices[0].finish_reason = (
+                            model_response.choices[0].finish_reason
+                        )
+                        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                        streaming_choice = litellm.utils.StreamingChoices()
+                        streaming_choice.index = model_response.choices[0].index
+                        _tool_calls = []
+                        print_verbose(
+                            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                        )
+                        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+                        if isinstance(model_response.choices[0], litellm.Choices):
+                            if getattr(
+                                model_response.choices[0].message, "tool_calls", None
+                            ) is not None and isinstance(
+                                model_response.choices[0].message.tool_calls, list
+                            ):
+                                for tool_call in model_response.choices[
+                                    0
+                                ].message.tool_calls:
+                                    _tool_call = {**tool_call.dict(), "index": 0}
+                                    _tool_calls.append(_tool_call)
+                            delta_obj = litellm.utils.Delta(
+                                content=getattr(
+                                    model_response.choices[0].message, "content", None
+                                ),
+                                role=model_response.choices[0].message.role,
+                                tool_calls=_tool_calls,
+                            )
+                            streaming_choice.delta = delta_obj
+                            streaming_model_response.choices = [streaming_choice]
+                            completion_stream = model_response_iterator(
+                                model_response=streaming_model_response
+                            )
+                            print_verbose(
+                                f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                            )
+                            return litellm.CustomStreamWrapper(
+                                completion_stream=completion_stream,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
                 model_response["finish_reason"] = response_body["stop_reason"]
                 _usage = litellm.Usage(
                     prompt_tokens=response_body["usage"]["input_tokens"],
@@ -1029,6 +1086,10 @@
         raise BedrockError(status_code=500, message=traceback.format_exc())
 
 
+async def model_response_iterator(model_response):
+    yield model_response
+
+
 def _embedding_func_single(
     model: str,
     input: str,
diff --git a/litellm/main.py b/litellm/main.py
index b516c5565..5f2b34482 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1753,7 +1753,11 @@ def completion(
                 timeout=timeout,
             )
 
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):  # don't try to access stream object,
                 if "ai21" in model:
                     response = CustomStreamWrapper(
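
For reference, a minimal usage sketch of the path this patch enables (not part of the diff): with `tools` plus `stream=True` against a Bedrock Claude 3 model, `completion()` now builds the parsed tool call up front and hands back a single-chunk `CustomStreamWrapper` (`custom_llm_provider="cached_response"`) instead of a raw Bedrock stream. The model id and tool schema below are illustrative assumptions, not taken from the patch.

```python
# Sketch only: exercises the streaming tool-calling path added in this patch.
# The model id and tool definition are illustrative, not taken from the diff.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model id
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    stream=True,
)

# Bedrock can't natively stream a parsed tool call, so the patch returns a
# CustomStreamWrapper over one pre-built chunk; iterating yields the
# tool_calls delta exactly once, mirroring the OpenAI streaming delta shape.
for chunk in response:
    print(chunk.choices[0].delta.tool_calls)
```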