forked from phoenix/litellm-mirror
fix(bedrock.py): support claude 3 function calling when stream=true
https://github.com/BerriAI/litellm/issues/2615
parent 425165dda9
commit 94f55aa6d9
2 changed files with 70 additions and 5 deletions
@@ -11,6 +11,7 @@ from .prompt_templates.factory import (
     construct_tool_use_system_prompt,
     extract_between_tags,
     parse_xml_params,
+    contains_tag,
 )
 import httpx
@@ -679,6 +680,7 @@ def completion(
     timeout=None,
 ):
     exception_mapping_worked = False
+    _is_function_call = False
     try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@@ -727,8 +729,10 @@ def completion(
                         system_messages.append(message["content"])
                         system_prompt_idx.append(idx)
                 if len(system_prompt_idx) > 0:
-                    inference_params["system"] = '\n'.join(system_messages)
-                    messages = [i for j, i in enumerate(messages) if j not in system_prompt_idx]
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
                 # Format rest of message according to anthropic guidelines
                 messages = prompt_factory(
                     model=model, messages=messages, custom_llm_provider="anthropic"
@@ -742,6 +746,7 @@ def completion(
                         inference_params[k] = v
                 ## Handle Tool Calling
                 if "tools" in inference_params:
+                    _is_function_call = True
                     tool_calling_system_prompt = construct_tool_use_system_prompt(
                         tools=inference_params["tools"]
                     )
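For context, construct_tool_use_system_prompt (already imported from prompt_templates.factory in the first hunk) renders the OpenAI-style tools list into an XML system prompt that asks Claude to answer with <invoke>/<tool_name> blocks, which is what the output-parsing change further down looks for. A rough, simplified sketch of that idea only, not litellm's actual template:

def construct_tool_use_system_prompt_sketch(tools):
    # Illustrative only: render each OpenAI-style tool definition as an XML
    # block and tell the model to reply with <invoke>/<tool_name> tags.
    tool_blocks = []
    for tool in tools:
        fn = tool["function"]
        tool_blocks.append(
            "<tool_description>\n"
            f"<tool_name>{fn['name']}</tool_name>\n"
            f"<description>{fn.get('description', '')}</description>\n"
            f"<parameters>{fn.get('parameters', {})}</parameters>\n"
            "</tool_description>"
        )
    return (
        "In this environment you have access to a set of tools. "
        "To call a tool, respond with a <function_calls> block containing an "
        "<invoke> element with a <tool_name> and its parameters.\n"
        "<tools>\n" + "\n".join(tool_blocks) + "\n</tools>"
    )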
@@ -823,7 +828,7 @@ def completion(
         ## COMPLETION CALL
         accept = "application/json"
         contentType = "application/json"
-        if stream == True:
+        if stream == True and _is_function_call == False:
             if provider == "ai21":
                 ## LOGGING
                 request_str = f"""
@@ -918,7 +923,9 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
-                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                if outputText is not None and contains_tag(
+                    "invoke", outputText
+                ):  # OUTPUT PARSE FUNCTION CALL
                     function_name = extract_between_tags("tool_name", outputText)[0]
                     function_arguments_str = extract_between_tags("invoke", outputText)[
                         0
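The parsing helpers referenced here come from prompt_templates/factory.py; the sketch below shows the rough regex-based behavior they are assumed to have. The added None check matters because running a substring or tag check against a missing text block would raise before any parsing happens.

import re
from typing import List

def contains_tag(tag: str, string: str) -> bool:
    # Assumed behavior: True if at least one <tag>...</tag> block is present.
    return bool(re.search(rf"<{tag}>(.+?)</{tag}>", string, re.DOTALL))

def extract_between_tags(tag: str, string: str, strip: bool = False) -> List[str]:
    # Assumed behavior: return every span enclosed in <tag>...</tag>.
    matches = re.findall(rf"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
    return [m.strip() for m in matches] if strip else matches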
@@ -941,6 +948,56 @@ def completion(
                         content=None,
                     )
                     model_response.choices[0].message = _message  # type: ignore
+                    if _is_function_call == True and stream is not None and stream == True:
+                        print_verbose(
+                            f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                        )
+                        # return an iterator
+                        streaming_model_response = ModelResponse(stream=True)
+                        streaming_model_response.choices[0].finish_reason = (
+                            model_response.choices[0].finish_reason
+                        )
+                        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                        streaming_choice = litellm.utils.StreamingChoices()
+                        streaming_choice.index = model_response.choices[0].index
+                        _tool_calls = []
+                        print_verbose(
+                            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                        )
+                        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+                        if isinstance(model_response.choices[0], litellm.Choices):
+                            if getattr(
+                                model_response.choices[0].message, "tool_calls", None
+                            ) is not None and isinstance(
+                                model_response.choices[0].message.tool_calls, list
+                            ):
+                                for tool_call in model_response.choices[
+                                    0
+                                ].message.tool_calls:
+                                    _tool_call = {**tool_call.dict(), "index": 0}
+                                    _tool_calls.append(_tool_call)
+                            delta_obj = litellm.utils.Delta(
+                                content=getattr(
+                                    model_response.choices[0].message, "content", None
+                                ),
+                                role=model_response.choices[0].message.role,
+                                tool_calls=_tool_calls,
+                            )
+                            streaming_choice.delta = delta_obj
+                            streaming_model_response.choices = [streaming_choice]
+                            completion_stream = model_response_iterator(
+                                model_response=streaming_model_response
+                            )
+                            print_verbose(
+                                f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                            )
+                            return litellm.CustomStreamWrapper(
+                                completion_stream=completion_stream,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
+
                 model_response["finish_reason"] = response_body["stop_reason"]
                 _usage = litellm.Usage(
                     prompt_tokens=response_body["usage"]["input_tokens"],
@@ -1029,6 +1086,10 @@ def completion(
         raise BedrockError(status_code=500, message=traceback.format_exc())


+async def model_response_iterator(model_response):
+    yield model_response
+
+
 def _embedding_func_single(
     model: str,
     input: str,
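The new branch above does not stream the tool call token by token: it parses the complete Claude 3 response, packs the tool calls into one streaming-shaped chunk, and returns that chunk through the model_response_iterator async generator wrapped in a CustomStreamWrapper with custom_llm_provider="cached_response", so callers that passed stream=True still receive something they can iterate. A minimal, self-contained sketch of that pattern follows; the chunk shape is illustrative, not litellm's exact ModelResponse:

import asyncio

async def model_response_iterator(model_response):
    # Yield the already-complete response exactly once, giving a pre-computed
    # result the same async-iterator interface as a real token stream.
    yield model_response

async def main():
    prebuilt_chunk = {
        "choices": [{
            "index": 0,
            "delta": {
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "index": 0,
                    "type": "function",
                    "function": {"name": "get_current_weather",
                                 "arguments": '{"location": "Boston"}'},
                }],
            },
            "finish_reason": "tool_calls",
        }]
    }
    async for chunk in model_response_iterator(prebuilt_chunk):
        print(chunk["choices"][0]["delta"]["tool_calls"])

asyncio.run(main())

The remaining hunk, in the second changed file (the top-level completion() router, judging by the context lines), keeps the caller from wrapping this already-wrapped response in a second CustomStreamWrapper.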
@@ -1753,7 +1753,11 @@ def completion(
                 timeout=timeout,
             )

-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 if "ai21" in model:
                     response = CustomStreamWrapper(
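A hypothetical end-to-end usage sketch for the behavior targeted by issue #2615; the model id and tool schema are illustrative, not taken from the diff:

import litellm

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

# With this commit, stream=True plus tools on a Bedrock Claude 3 model should
# return an iterable whose single chunk carries the parsed tool call.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta)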