fix(factory.py): parse list in xml tool calling response (anthropic)

Improves tool-calling output parsing to check whether the response contains a list. Also returns the raw response to the user via `response._hidden_params["original_response"]`, so the user can see exactly what Anthropic returned.
This commit is contained in:
Krrish Dholakia 2024-03-29 11:51:26 -07:00
parent de31af0dc2
commit 69f27aa25c
4 changed files with 114 additions and 11 deletions

View file

@ -3,7 +3,7 @@ import json
from enum import Enum from enum import Enum
import requests, copy import requests, copy
import time, uuid import time, uuid
from typing import Callable, Optional from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import ( from .prompt_templates.factory import (
@ -118,6 +118,7 @@ def completion(
): ):
headers = validate_environment(api_key, headers) headers = validate_environment(api_key, headers)
_is_function_call = False _is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages) messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params) optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict: if model in custom_prompt_dict:
@ -161,6 +162,10 @@ def completion(
## Handle Tool Calling ## Handle Tool Calling
if "tools" in optional_params: if "tools" in optional_params:
_is_function_call = True _is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt( tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"] tools=optional_params["tools"]
) )
@ -248,7 +253,12 @@ def completion(
0 0
].strip() ].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>" function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str) function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message( _message = litellm.Message(
tool_calls=[ tool_calls=[
{ {
@ -263,6 +273,9 @@ def completion(
content=None, content=None,
) )
model_response.choices[0].message = _message # type: ignore model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else: else:
model_response.choices[0].message.content = text_content # type: ignore model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(

View file

@ -691,6 +691,7 @@ def completion(
): ):
exception_mapping_worked = False exception_mapping_worked = False
_is_function_call = False _is_function_call = False
json_schemas: dict = {}
try: try:
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
@ -757,6 +758,10 @@ def completion(
## Handle Tool Calling ## Handle Tool Calling
if "tools" in inference_params: if "tools" in inference_params:
_is_function_call = True _is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt( tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"] tools=inference_params["tools"]
) )
@ -943,7 +948,12 @@ def completion(
function_arguments_str = ( function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>" f"<invoke>{function_arguments_str}</invoke>"
) )
function_arguments = parse_xml_params(function_arguments_str) function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name)
)
_message = litellm.Message( _message = litellm.Message(
tool_calls=[ tool_calls=[
{ {
@ -958,6 +968,9 @@ def completion(
content=None, content=None,
) )
model_response.choices[0].message = _message # type: ignore model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
if _is_function_call == True and stream is not None and stream == True: if _is_function_call == True and stream is not None and stream == True:
print_verbose( print_verbose(
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK" f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"

View file

@ -731,18 +731,53 @@ def contains_tag(tag: str, string: str) -> bool:
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)) return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
    """
    Parse the `<parameters>` of an XML tool-call block into a dict of arguments.

    When a JSON schema with a non-empty "properties" mapping is supplied, the
    XML is compared against it: a property whose schema "type" is "array" is
    built from the text of its child elements (always yielding a list, empty
    if the element is missing), and any other property is read from the
    element's own text. Only properties named in the schema are returned.

    Without a schema - or with a schema that declares no properties - every
    element found under a `<parameters>` element is returned, keyed by tag.

    Each text value is decoded as JSON when possible, otherwise kept as the
    raw string. Elements with no text are skipped.

    Args:
        xml_content: XML string, e.g. "<invoke>...<parameters>...</parameters></invoke>".
        json_schema: Optional JSON schema ("parameters" of the tool definition).

    Returns:
        dict mapping parameter names to their decoded values.

    Raises:
        xml.etree.ElementTree.ParseError: if `xml_content` is not well-formed XML.
    """

    def _decode(text):
        # Values such as numbers/booleans arrive as JSON; plain strings do not.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text

    root = ET.fromstring(xml_content)
    params = {}

    # A schema may legitimately omit "properties"; treat that like "no schema"
    # instead of raising KeyError.
    properties = json_schema.get("properties") if json_schema is not None else None

    if properties:
        # Index the children of the root-level <parameters> element(s) by tag
        # (first occurrence wins) rather than interpolating the property name
        # into an ElementPath expression.
        elements = {}
        for parameters in root.findall("parameters"):
            for child in parameters:
                elements.setdefault(child.tag, child)

        for prop, prop_schema in properties.items():
            _element = elements.get(prop)
            # "type" is optional in JSON schema (e.g. enum-only properties).
            prop_type = (
                prop_schema.get("type") if isinstance(prop_schema, dict) else None
            )
            if prop_type == "array":
                # If property is an array, collect the nested items
                params[prop] = [
                    _decode(value.text)
                    for value in (_element if _element is not None else ())
                    if value.text is not None
                ]
            elif _element is not None and _element.text is not None:
                # If property is not an array, use the element's value directly
                params[prop] = _decode(_element.text)
    else:
        for child in root.findall(".//parameters/*"):
            if child.text is not None:
                params[child.tag] = _decode(child.text)

    return params
### ###

View file

@ -195,6 +195,48 @@ def test_completion_claude_3_function_call():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_parse_xml_params():
    """
    Check parse_xml_params against a JSON schema for two tool-call payloads:
    one whose parameter is an array of items, and one with plain string
    parameters.
    """
    from litellm.llms.prompt_templates.factory import parse_xml_params

    # Scenario 1: the schema declares "value" as an array, so the nested
    # <item> elements should come back as a list of strings.
    array_xml = """<invoke><tool_name>return_list_of_str</tool_name>\n<parameters>\n<value>\n<item>apple</item>\n<item>banana</item>\n<item>orange</item>\n</value>\n</parameters></invoke>"""
    array_schema = {
        "properties": {
            "value": {
                "items": {"type": "string"},
                "title": "Value",
                "type": "array",
            }
        },
        "required": ["value"],
        "type": "object",
    }
    array_result = parse_xml_params(xml_content=array_xml, json_schema=array_schema)
    print(f"response: {array_result}")
    assert array_result["value"] == ["apple", "banana", "orange"]

    # Scenario 2: no array properties, so each parameter is read directly
    # from its element's text.
    weather_xml = """<invoke><tool_name>get_current_weather</tool_name>\n<parameters>\n<location>Boston, MA</location>\n<unit>fahrenheit</unit>\n</parameters></invoke>"""
    weather_schema = {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    }
    weather_result = parse_xml_params(
        xml_content=weather_xml, json_schema=weather_schema
    )
    print(f"response: {weather_result}")
    assert weather_result["location"] == "Boston, MA"
    assert weather_result["unit"] == "fahrenheit"
def test_completion_claude_3_multi_turn_conversations(): def test_completion_claude_3_multi_turn_conversations():
litellm.set_verbose = True litellm.set_verbose = True
litellm.modify_params = True litellm.modify_params = True