fix(bedrock.py): support claude 3 function calling when stream=true

https://github.com/BerriAI/litellm/issues/2615
Krrish Dholakia 2024-03-21 18:39:03 -07:00
parent 425165dda9
commit 94f55aa6d9
2 changed files with 70 additions and 5 deletions
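
For context, the call pattern this commit targets looks roughly like the sketch below: streaming plus tool calling against a Bedrock Claude 3 model through litellm.completion. The model ID and tool schema are illustrative, not taken from the commit.

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool, not part of the commit
            "description": "Get the current weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock Claude 3 model ID
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    stream=True,  # the combination this commit fixes for Bedrock Claude 3
)

for chunk in response:
    # tool call deltas arrive on chunk.choices[0].delta once this fix is in place
    print(chunk.choices[0].delta)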

bedrock.py

@@ -11,6 +11,7 @@ from .prompt_templates.factory import (
     construct_tool_use_system_prompt,
     extract_between_tags,
     parse_xml_params,
+    contains_tag,
 )
 import httpx
@@ -679,6 +680,7 @@ def completion(
     timeout=None,
 ):
     exception_mapping_worked = False
+    _is_function_call = False
     try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@@ -727,8 +729,10 @@ def completion(
                         system_messages.append(message["content"])
                         system_prompt_idx.append(idx)
                 if len(system_prompt_idx) > 0:
-                    inference_params["system"] = '\n'.join(system_messages)
-                    messages = [i for j, i in enumerate(messages) if j not in system_prompt_idx]
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
                 # Format rest of message according to anthropic guidelines
                 messages = prompt_factory(
                     model=model, messages=messages, custom_llm_provider="anthropic"
@@ -742,6 +746,7 @@ def completion(
                         inference_params[k] = v
                 ## Handle Tool Calling
                 if "tools" in inference_params:
+                    _is_function_call = True
                     tool_calling_system_prompt = construct_tool_use_system_prompt(
                         tools=inference_params["tools"]
                     )
@@ -823,7 +828,7 @@ def completion(
         ## COMPLETION CALL
         accept = "application/json"
         contentType = "application/json"
-        if stream == True:
+        if stream == True and _is_function_call == False:
             if provider == "ai21":
                 ## LOGGING
                 request_str = f"""
@@ -918,7 +923,9 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
-                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                if outputText is not None and contains_tag(
+                    "invoke", outputText
+                ):  # OUTPUT PARSE FUNCTION CALL
                     function_name = extract_between_tags("tool_name", outputText)[0]
                     function_arguments_str = extract_between_tags("invoke", outputText)[
                         0
@@ -941,6 +948,56 @@ def completion(
                         content=None,
                     )
                     model_response.choices[0].message = _message  # type: ignore
+                if _is_function_call == True and stream is not None and stream == True:
+                    print_verbose(
+                        f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                    )
+                    # return an iterator
+                    streaming_model_response = ModelResponse(stream=True)
+                    streaming_model_response.choices[0].finish_reason = (
+                        model_response.choices[0].finish_reason
+                    )
+                    # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                    streaming_choice = litellm.utils.StreamingChoices()
+                    streaming_choice.index = model_response.choices[0].index
+                    _tool_calls = []
+                    print_verbose(
+                        f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                    )
+                    print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+                    if isinstance(model_response.choices[0], litellm.Choices):
+                        if getattr(
+                            model_response.choices[0].message, "tool_calls", None
+                        ) is not None and isinstance(
+                            model_response.choices[0].message.tool_calls, list
+                        ):
+                            for tool_call in model_response.choices[
+                                0
+                            ].message.tool_calls:
+                                _tool_call = {**tool_call.dict(), "index": 0}
+                                _tool_calls.append(_tool_call)
+                        delta_obj = litellm.utils.Delta(
+                            content=getattr(
+                                model_response.choices[0].message, "content", None
+                            ),
+                            role=model_response.choices[0].message.role,
+                            tool_calls=_tool_calls,
+                        )
+                        streaming_choice.delta = delta_obj
+                        streaming_model_response.choices = [streaming_choice]
+                        completion_stream = model_response_iterator(
+                            model_response=streaming_model_response
+                        )
+                        print_verbose(
+                            f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                        )
+                        return litellm.CustomStreamWrapper(
+                            completion_stream=completion_stream,
+                            model=model,
+                            custom_llm_provider="cached_response",
+                            logging_obj=logging_obj,
+                        )
                 model_response["finish_reason"] = response_body["stop_reason"]
                 _usage = litellm.Usage(
                     prompt_tokens=response_body["usage"]["input_tokens"],
@@ -1029,6 +1086,10 @@ def completion(
         raise BedrockError(status_code=500, message=traceback.format_exc())
 
 
+async def model_response_iterator(model_response):
+    yield model_response
+
+
 def _embedding_func_single(
     model: str,
     input: str,
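
With the change above, the function-calling path returns a CustomStreamWrapper that replays the already-complete response as a single chunk (custom_llm_provider="cached_response"). On the consumer side, tool calls can be reassembled with the usual delta-accumulation loop; a rough sketch, with an illustrative helper name that is not part of litellm:

def collect_tool_calls(stream):
    """Accumulate tool call names/arguments from streamed deltas."""
    calls = {}
    for chunk in stream:
        delta = chunk.choices[0].delta
        for tc in getattr(delta, "tool_calls", None) or []:
            slot = calls.setdefault(tc.index, {"name": "", "arguments": ""})
            if tc.function.name:
                slot["name"] = tc.function.name
            if tc.function.arguments:
                slot["arguments"] += tc.function.arguments
    return calls

Because the cached_response stream yields exactly one chunk here, the loop runs once; for providers that stream tool calls incrementally, the same loop concatenates the argument fragments.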

Second changed file

@@ -1753,7 +1753,11 @@ def completion(
                 timeout=timeout,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 if "ai21" in model:
                     response = CustomStreamWrapper(