mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
fix(factory.py): parse list in xml tool calling response (anthropic)
improves tool calling outparsing to check if list in response. Also returns the raw response back to the user via `response._hidden_params["original_response"]`, so user can see exactly what anthropic returned
This commit is contained in:
parent
de31af0dc2
commit
69f27aa25c
4 changed files with 114 additions and 11 deletions
|
@ -3,7 +3,7 @@ import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests, copy
|
import requests, copy
|
||||||
import time, uuid
|
import time, uuid
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, List
|
||||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import (
|
from .prompt_templates.factory import (
|
||||||
|
@ -118,6 +118,7 @@ def completion(
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key, headers)
|
headers = validate_environment(api_key, headers)
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
|
json_schemas: dict = {}
|
||||||
messages = copy.deepcopy(messages)
|
messages = copy.deepcopy(messages)
|
||||||
optional_params = copy.deepcopy(optional_params)
|
optional_params = copy.deepcopy(optional_params)
|
||||||
if model in custom_prompt_dict:
|
if model in custom_prompt_dict:
|
||||||
|
@ -161,6 +162,10 @@ def completion(
|
||||||
## Handle Tool Calling
|
## Handle Tool Calling
|
||||||
if "tools" in optional_params:
|
if "tools" in optional_params:
|
||||||
_is_function_call = True
|
_is_function_call = True
|
||||||
|
for tool in optional_params["tools"]:
|
||||||
|
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||||
|
"parameters", None
|
||||||
|
)
|
||||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||||
tools=optional_params["tools"]
|
tools=optional_params["tools"]
|
||||||
)
|
)
|
||||||
|
@ -248,7 +253,12 @@ def completion(
|
||||||
0
|
0
|
||||||
].strip()
|
].strip()
|
||||||
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
|
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
|
||||||
function_arguments = parse_xml_params(function_arguments_str)
|
function_arguments = parse_xml_params(
|
||||||
|
function_arguments_str,
|
||||||
|
json_schema=json_schemas.get(
|
||||||
|
function_name, None
|
||||||
|
), # check if we have a json schema for this function name
|
||||||
|
)
|
||||||
_message = litellm.Message(
|
_message = litellm.Message(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
|
@ -263,6 +273,9 @@ def completion(
|
||||||
content=None,
|
content=None,
|
||||||
)
|
)
|
||||||
model_response.choices[0].message = _message # type: ignore
|
model_response.choices[0].message = _message # type: ignore
|
||||||
|
model_response._hidden_params["original_response"] = (
|
||||||
|
text_content # allow user to access raw anthropic tool calling response
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_response.choices[0].message.content = text_content # type: ignore
|
model_response.choices[0].message.content = text_content # type: ignore
|
||||||
model_response.choices[0].finish_reason = map_finish_reason(
|
model_response.choices[0].finish_reason = map_finish_reason(
|
||||||
|
|
|
@ -691,6 +691,7 @@ def completion(
|
||||||
):
|
):
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
|
json_schemas: dict = {}
|
||||||
try:
|
try:
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
@ -757,6 +758,10 @@ def completion(
|
||||||
## Handle Tool Calling
|
## Handle Tool Calling
|
||||||
if "tools" in inference_params:
|
if "tools" in inference_params:
|
||||||
_is_function_call = True
|
_is_function_call = True
|
||||||
|
for tool in inference_params["tools"]:
|
||||||
|
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||||
|
"parameters", None
|
||||||
|
)
|
||||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||||
tools=inference_params["tools"]
|
tools=inference_params["tools"]
|
||||||
)
|
)
|
||||||
|
@ -943,7 +948,12 @@ def completion(
|
||||||
function_arguments_str = (
|
function_arguments_str = (
|
||||||
f"<invoke>{function_arguments_str}</invoke>"
|
f"<invoke>{function_arguments_str}</invoke>"
|
||||||
)
|
)
|
||||||
function_arguments = parse_xml_params(function_arguments_str)
|
function_arguments = parse_xml_params(
|
||||||
|
function_arguments_str,
|
||||||
|
json_schema=json_schemas.get(
|
||||||
|
function_name, None
|
||||||
|
), # check if we have a json schema for this function name)
|
||||||
|
)
|
||||||
_message = litellm.Message(
|
_message = litellm.Message(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
|
@ -958,6 +968,9 @@ def completion(
|
||||||
content=None,
|
content=None,
|
||||||
)
|
)
|
||||||
model_response.choices[0].message = _message # type: ignore
|
model_response.choices[0].message = _message # type: ignore
|
||||||
|
model_response._hidden_params["original_response"] = (
|
||||||
|
outputText # allow user to access raw anthropic tool calling response
|
||||||
|
)
|
||||||
if _is_function_call == True and stream is not None and stream == True:
|
if _is_function_call == True and stream is not None and stream == True:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
|
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
|
||||||
|
|
|
@ -731,18 +731,53 @@ def contains_tag(tag: str, string: str) -> bool:
|
||||||
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
|
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
|
||||||
|
|
||||||
|
|
||||||
def parse_xml_params(xml_content):
|
def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
Compare the xml output to the json schema
|
||||||
|
|
||||||
|
check if a value is a list - if so, get it's child elements
|
||||||
|
"""
|
||||||
root = ET.fromstring(xml_content)
|
root = ET.fromstring(xml_content)
|
||||||
params = {}
|
params = {}
|
||||||
for child in root.findall(".//parameters/*"):
|
|
||||||
try:
|
if json_schema is not None: # check if we have a json schema for this function call
|
||||||
# Attempt to decode the element's text as JSON
|
# iterate over all properties in the schema
|
||||||
params[child.tag] = json.loads(child.text)
|
for prop in json_schema["properties"]:
|
||||||
except json.JSONDecodeError:
|
# If property is an array, get the nested items
|
||||||
# If JSON decoding fails, use the original text
|
_element = root.find(f"parameters/{prop}")
|
||||||
params[child.tag] = child.text
|
if json_schema["properties"][prop]["type"] == "array":
|
||||||
|
items = []
|
||||||
|
if _element is not None:
|
||||||
|
for value in _element:
|
||||||
|
try:
|
||||||
|
if value.text is not None:
|
||||||
|
_value = json.loads(value.text)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
_value = value.text
|
||||||
|
items.append(_value)
|
||||||
|
params[prop] = items
|
||||||
|
# If property is not an array, append the value directly
|
||||||
|
elif _element is not None and _element.text is not None:
|
||||||
|
try:
|
||||||
|
_value = json.loads(_element.text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
_value = _element.text
|
||||||
|
params[prop] = _value
|
||||||
|
else:
|
||||||
|
for child in root.findall(".//parameters/*"):
|
||||||
|
if child is not None and child.text is not None:
|
||||||
|
try:
|
||||||
|
# Attempt to decode the element's text as JSON
|
||||||
|
params[child.tag] = json.loads(child.text) # type: ignore
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If JSON decoding fails, use the original text
|
||||||
|
params[child.tag] = child.text # type: ignore
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -195,6 +195,48 @@ def test_completion_claude_3_function_call():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_xml_params():
|
||||||
|
from litellm.llms.prompt_templates.factory import parse_xml_params
|
||||||
|
|
||||||
|
## SCENARIO 1 ## - W/ ARRAY
|
||||||
|
xml_content = """<invoke><tool_name>return_list_of_str</tool_name>\n<parameters>\n<value>\n<item>apple</item>\n<item>banana</item>\n<item>orange</item>\n</value>\n</parameters></invoke>"""
|
||||||
|
json_schema = {
|
||||||
|
"properties": {
|
||||||
|
"value": {
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"title": "Value",
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["value"],
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
response = parse_xml_params(xml_content=xml_content, json_schema=json_schema)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
assert response["value"] == ["apple", "banana", "orange"]
|
||||||
|
|
||||||
|
## SCENARIO 2 ## - W/OUT ARRAY
|
||||||
|
xml_content = """<invoke><tool_name>get_current_weather</tool_name>\n<parameters>\n<location>Boston, MA</location>\n<unit>fahrenheit</unit>\n</parameters></invoke>"""
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = parse_xml_params(xml_content=xml_content, json_schema=json_schema)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
assert response["location"] == "Boston, MA"
|
||||||
|
assert response["unit"] == "fahrenheit"
|
||||||
|
|
||||||
|
|
||||||
def test_completion_claude_3_multi_turn_conversations():
|
def test_completion_claude_3_multi_turn_conversations():
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
litellm.modify_params = True
|
litellm.modify_params = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue