fix(factory.py): parse list in xml tool calling response (anthropic)

Improves tool-calling output parsing to check whether the response contains a list. Also returns the raw response to the user via `response._hidden_params["original_response"]`, so the user can see exactly what Anthropic returned.
This commit is contained in:
Krrish Dholakia 2024-03-29 11:51:26 -07:00
parent de31af0dc2
commit 69f27aa25c
4 changed files with 114 additions and 11 deletions

View file

@ -3,7 +3,7 @@ import json
from enum import Enum
import requests, copy
import time, uuid
from typing import Callable, Optional
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import (
@ -118,6 +118,7 @@ def completion(
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
@ -161,6 +162,10 @@ def completion(
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
@ -248,7 +253,12 @@ def completion(
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
{
@ -263,6 +273,9 @@ def completion(
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(

View file

@ -691,6 +691,7 @@ def completion(
):
exception_mapping_worked = False
_is_function_call = False
json_schemas: dict = {}
try:
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@ -757,6 +758,10 @@ def completion(
## Handle Tool Calling
if "tools" in inference_params:
_is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
@ -943,7 +948,12 @@ def completion(
function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>"
)
function_arguments = parse_xml_params(function_arguments_str)
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name)
)
_message = litellm.Message(
tool_calls=[
{
@ -958,6 +968,9 @@ def completion(
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
if _is_function_call == True and stream is not None and stream == True:
print_verbose(
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"

View file

@ -731,18 +731,53 @@ def contains_tag(tag: str, string: str) -> bool:
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
def _decode_xml_text(text):
    """Return `text` decoded as JSON when it is valid JSON, else the raw string."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text


def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
    """
    Parse the `<parameters>` of an XML tool-call block into a dict.

    Each parameter's text is decoded as JSON when possible and kept as the
    raw string otherwise. Elements whose text is None are skipped.

    Args:
        xml_content: XML string with a `<parameters>` element under the root.
        json_schema: Optional JSON schema for the called function. When it
            carries a `properties` mapping, only the declared properties are
            read: a property typed "array" is built from the child elements
            of its XML node (an absent node yields an empty list), any other
            property takes the node's own text. Parameters present in the XML
            but not declared in the schema are ignored in this mode.

    Returns:
        Dict mapping parameter names to their decoded values.

    Raises:
        xml.etree.ElementTree.ParseError: if `xml_content` is not valid XML.
    """
    root = ET.fromstring(xml_content)

    properties = json_schema.get("properties") if json_schema is not None else None
    if not isinstance(properties, dict):
        # No schema, or a schema without a usable `properties` mapping
        # (e.g. {"type": "object"}): fall back to reading every parameter.
        return {
            child.tag: _decode_xml_text(child.text)
            for child in root.findall(".//parameters/*")
            if child.text is not None
        }

    # Index parameter nodes by tag (first occurrence wins) instead of building
    # a path expression from the property name, so names containing
    # path-special characters cannot break the lookup.
    elements = {}
    for child in root.findall("parameters/*"):
        elements.setdefault(child.tag, child)

    params = {}
    for prop, prop_schema in properties.items():
        element = elements.get(prop)
        # `type` is optional in JSON schema (enum-only / anyOf properties) and
        # may also be a list of type names.
        prop_type = prop_schema.get("type") if isinstance(prop_schema, dict) else None
        is_array = prop_type == "array" or (
            isinstance(prop_type, list) and "array" in prop_type
        )
        if is_array:
            # Compare against None explicitly: an Element without children is falsy.
            children = element if element is not None else ()
            params[prop] = [
                _decode_xml_text(item.text)
                for item in children
                if item.text is not None
            ]
        elif element is not None and element.text is not None:
            params[prop] = _decode_xml_text(element.text)
    return params
###

View file

@ -195,6 +195,48 @@ def test_completion_claude_3_function_call():
pytest.fail(f"Error occurred: {e}")
def test_parse_xml_params():
    """Check parse_xml_params against an array-typed and a scalar-typed schema."""
    from litellm.llms.prompt_templates.factory import parse_xml_params

    # Each scenario: (raw XML tool call, JSON schema of the tool, expected values)
    scenarios = [
        ## SCENARIO 1 ## - W/ ARRAY
        (
            """<invoke><tool_name>return_list_of_str</tool_name>\n<parameters>\n<value>\n<item>apple</item>\n<item>banana</item>\n<item>orange</item>\n</value>\n</parameters></invoke>""",
            {
                "properties": {
                    "value": {
                        "items": {"type": "string"},
                        "title": "Value",
                        "type": "array",
                    }
                },
                "required": ["value"],
                "type": "object",
            },
            {"value": ["apple", "banana", "orange"]},
        ),
        ## SCENARIO 2 ## - W/OUT ARRAY
        (
            """<invoke><tool_name>get_current_weather</tool_name>\n<parameters>\n<location>Boston, MA</location>\n<unit>fahrenheit</unit>\n</parameters></invoke>""",
            {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
            {"location": "Boston, MA", "unit": "fahrenheit"},
        ),
    ]

    for raw_xml, schema, expected_values in scenarios:
        response = parse_xml_params(xml_content=raw_xml, json_schema=schema)
        print(f"response: {response}")
        for key, expected in expected_values.items():
            assert response[key] == expected
def test_completion_claude_3_multi_turn_conversations():
litellm.set_verbose = True
litellm.modify_params = True