forked from phoenix/litellm-mirror

fix(bedrock.py): support claude 3 function calling when stream=true
https://github.com/BerriAI/litellm/issues/2615

parent 425165dda9
commit 94f55aa6d9

2 changed files with 70 additions and 5 deletions
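For context, the linked issue (#2615) concerns requests that combine Anthropic Claude 3 on Bedrock with both `tools` and `stream=True`. Below is a minimal sketch of that scenario through `litellm.completion`; the model ID, tool schema, and prompt are illustrative assumptions, not taken from this commit. The consuming side of the same stream is sketched after the final hunk.

```python
# Hypothetical repro of the scenario this commit addresses: Bedrock Claude 3
# with an OpenAI-style tool definition and stream=True.
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool, not from the diff
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model ID
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    stream=True,
)
```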
@@ -11,6 +11,7 @@ from .prompt_templates.factory import (
     construct_tool_use_system_prompt,
     extract_between_tags,
     parse_xml_params,
+    contains_tag,
 )
 import httpx

@@ -679,6 +680,7 @@ def completion(
     timeout=None,
 ):
     exception_mapping_worked = False
+    _is_function_call = False
     try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@@ -727,8 +729,10 @@ def completion(
                         system_messages.append(message["content"])
                         system_prompt_idx.append(idx)
                 if len(system_prompt_idx) > 0:
-                    inference_params["system"] = '\n'.join(system_messages)
-                    messages = [i for j, i in enumerate(messages) if j not in system_prompt_idx]
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
                 # Format rest of message according to anthropic guidelines
                 messages = prompt_factory(
                     model=model, messages=messages, custom_llm_provider="anthropic"
@@ -742,6 +746,7 @@ def completion(
                         inference_params[k] = v
                 ## Handle Tool Calling
                 if "tools" in inference_params:
+                    _is_function_call = True
                     tool_calling_system_prompt = construct_tool_use_system_prompt(
                         tools=inference_params["tools"]
                     )
@@ -823,7 +828,7 @@ def completion(
         ## COMPLETION CALL
         accept = "application/json"
         contentType = "application/json"
-        if stream == True:
+        if stream == True and _is_function_call == False:
             if provider == "ai21":
                 ## LOGGING
                 request_str = f"""
@@ -918,7 +923,9 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
-                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                if outputText is not None and contains_tag(
+                    "invoke", outputText
+                ):  # OUTPUT PARSE FUNCTION CALL
                     function_name = extract_between_tags("tool_name", outputText)[0]
                     function_arguments_str = extract_between_tags("invoke", outputText)[
                         0
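The reworked condition above checks the model output for Claude's XML-style tool-use block before parsing it, and also tolerates `outputText` being `None`. `contains_tag` and `extract_between_tags` are the helpers imported from `prompt_templates.factory` in the first hunk; the sketch below only approximates their behavior with regular expressions, and the sample payload is invented.

```python
# Rough approximation (not litellm's implementation) of the tag handling used
# in the hunk above, with an invented Claude-style tool-use payload.
import re
from typing import List


def contains_tag_approx(tag: str, text: str) -> bool:
    # True when a <tag>...</tag> block appears in the model output.
    return bool(re.search(rf"<{tag}>.*?</{tag}>", text, re.DOTALL))


def extract_between_tags_approx(tag: str, text: str) -> List[str]:
    # Return every span enclosed in <tag>...</tag>.
    return re.findall(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)


sample = (
    "<invoke><tool_name>get_current_weather</tool_name>"
    "<parameters><location>Boston</location></parameters></invoke>"
)

if contains_tag_approx("invoke", sample):
    print(extract_between_tags_approx("tool_name", sample))  # ['get_current_weather']
```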
@@ -941,6 +948,56 @@ def completion(
                         content=None,
                     )
                     model_response.choices[0].message = _message  # type: ignore
+                if _is_function_call == True and stream is not None and stream == True:
+                    print_verbose(
+                        f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                    )
+                    # return an iterator
+                    streaming_model_response = ModelResponse(stream=True)
+                    streaming_model_response.choices[0].finish_reason = (
+                        model_response.choices[0].finish_reason
+                    )
+                    # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                    streaming_choice = litellm.utils.StreamingChoices()
+                    streaming_choice.index = model_response.choices[0].index
+                    _tool_calls = []
+                    print_verbose(
+                        f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                    )
+                    print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+                    if isinstance(model_response.choices[0], litellm.Choices):
+                        if getattr(
+                            model_response.choices[0].message, "tool_calls", None
+                        ) is not None and isinstance(
+                            model_response.choices[0].message.tool_calls, list
+                        ):
+                            for tool_call in model_response.choices[
+                                0
+                            ].message.tool_calls:
+                                _tool_call = {**tool_call.dict(), "index": 0}
+                                _tool_calls.append(_tool_call)
+                        delta_obj = litellm.utils.Delta(
+                            content=getattr(
+                                model_response.choices[0].message, "content", None
+                            ),
+                            role=model_response.choices[0].message.role,
+                            tool_calls=_tool_calls,
+                        )
+                        streaming_choice.delta = delta_obj
+                        streaming_model_response.choices = [streaming_choice]
+                        completion_stream = model_response_iterator(
+                            model_response=streaming_model_response
+                        )
+                        print_verbose(
+                            f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                        )
+                        return litellm.CustomStreamWrapper(
+                            completion_stream=completion_stream,
+                            model=model,
+                            custom_llm_provider="cached_response",
+                            logging_obj=logging_obj,
+                        )
+
                 model_response["finish_reason"] = response_body["stop_reason"]
                 _usage = litellm.Usage(
                     prompt_tokens=response_body["usage"]["input_tokens"],
@@ -1029,6 +1086,10 @@ def completion(
         raise BedrockError(status_code=500, message=traceback.format_exc())


+async def model_response_iterator(model_response):
+    yield model_response
+
+
 def _embedding_func_single(
     model: str,
     input: str,
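The streaming branch added above never opens a real Bedrock stream: it builds the complete tool-call message first, copies it into a single `StreamingChoices` delta, and replays it through the new `model_response_iterator` so `litellm.CustomStreamWrapper` (flagged as `custom_llm_provider="cached_response"`) can hand it to the caller as a stream. Below is a self-contained sketch of that replay pattern, using stand-in types rather than litellm's classes.

```python
# Standalone sketch of "replay a finished response as a one-chunk stream";
# FakeChunk and replay_as_stream are stand-ins, not litellm APIs.
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator


@dataclass
class FakeChunk:
    delta: dict = field(default_factory=dict)
    finish_reason: str = "tool_calls"


async def replay_as_stream(chunk: FakeChunk) -> AsyncIterator[FakeChunk]:
    # Mirrors the role of model_response_iterator(): yield the fully built
    # response exactly once, so stream consumers still get an iterator.
    yield chunk


async def main() -> None:
    chunk = FakeChunk(
        delta={
            "role": "assistant",
            "content": None,
            "tool_calls": [{"index": 0, "function": {"name": "get_current_weather"}}],
        }
    )
    async for piece in replay_as_stream(chunk):
        print(piece.delta["tool_calls"])


asyncio.run(main())
```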
@@ -1753,7 +1753,11 @@ def completion(
                timeout=timeout,
            )

-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                # don't try to access stream object,
                if "ai21" in model:
                    response = CustomStreamWrapper(
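With the `isinstance` guard in this last hunk (from the second changed file, where `completion()` dispatches to the Bedrock handler), a `CustomStreamWrapper` that the handler already returned is passed through instead of being wrapped a second time, so the caller receives one stream whose deltas carry the tool call. One way a caller might reassemble it, continuing the illustrative request from the top of this page; the accumulation logic is generic OpenAI-style chunk handling, not taken from the diff.

```python
# Illustrative consumption of the streamed tool call; `response` is the
# streaming result of the litellm.completion(...) sketch shown earlier.
tool_name = None
arguments = ""

for chunk in response:
    delta = chunk.choices[0].delta
    for tool_call in getattr(delta, "tool_calls", None) or []:
        if tool_call.function.name:
            tool_name = tool_call.function.name
        if tool_call.function.arguments:
            arguments += tool_call.function.arguments

print(tool_name, arguments)
```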