Mirror of https://github.com/BerriAI/litellm.git
fix(factory.py): parse list in xml tool calling response (anthropic)
Improves tool-calling output parsing to check whether the response contains a list. Also returns the raw response to the user via `response._hidden_params["original_response"]`, so the user can see exactly what Anthropic returned.
parent de31af0dc2
commit 69f27aa25c
4 changed files with 114 additions and 11 deletions
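
As a rough usage sketch of what the change enables (the model name, prompt, and tool schema below are illustrative examples, not taken from this commit): a tool whose parameter is an array of strings now round-trips through the XML output parser, and the raw Anthropic text is exposed on the returned ModelResponse.

import litellm

# Illustrative tool whose single parameter is an array of strings -- the case this commit fixes.
tools = [
    {
        "type": "function",
        "function": {
            "name": "return_list_of_str",
            "description": "Return a list of strings",
            "parameters": {
                "type": "object",
                "properties": {
                    "value": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["value"],
            },
        },
    }
]

# Requires an Anthropic API key in the environment; the model name is just an example.
response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Give me three fruits as a list"}],
    tools=tools,
)

# Parsed, OpenAI-style tool call (arguments recovered from Claude's XML, lists included).
print(response.choices[0].message.tool_calls)

# Raw text returned by Anthropic, exposed by this commit.
print(response._hidden_params["original_response"])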

Anthropic completion handler:

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests, copy
 import time, uuid
-from typing import Callable, Optional
+from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (

@@ -118,6 +118,7 @@ def completion(
 ):
     headers = validate_environment(api_key, headers)
     _is_function_call = False
+    json_schemas: dict = {}
     messages = copy.deepcopy(messages)
     optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:

@@ -161,6 +162,10 @@ def completion(
     ## Handle Tool Calling
     if "tools" in optional_params:
         _is_function_call = True
+        for tool in optional_params["tools"]:
+            json_schemas[tool["function"]["name"]] = tool["function"].get(
+                "parameters", None
+            )
         tool_calling_system_prompt = construct_tool_use_system_prompt(
             tools=optional_params["tools"]
         )

@@ -248,7 +253,12 @@ def completion(
                 0
             ].strip()
             function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
-            function_arguments = parse_xml_params(function_arguments_str)
+            function_arguments = parse_xml_params(
+                function_arguments_str,
+                json_schema=json_schemas.get(
+                    function_name, None
+                ),  # check if we have a json schema for this function name
+            )
             _message = litellm.Message(
                 tool_calls=[
                     {

@@ -263,6 +273,9 @@ def completion(
                 content=None,
             )
             model_response.choices[0].message = _message  # type: ignore
+            model_response._hidden_params["original_response"] = (
+                text_content  # allow user to access raw anthropic tool calling response
+            )
         else:
             model_response.choices[0].message.content = text_content  # type: ignore
         model_response.choices[0].finish_reason = map_finish_reason(
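
For clarity, the shape of the data the new json_schemas map works with -- the tools list below is an illustrative OpenAI-style definition (not from the commit), the same shape the handler reads from optional_params["tools"]:

# Illustrative OpenAI-style tools list.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Same mapping the hunk above builds: tool name -> its JSON schema (or None).
json_schemas: dict = {}
for tool in tools:
    json_schemas[tool["function"]["name"]] = tool["function"].get("parameters", None)

# When Claude's XML tool call is parsed, the matching schema is looked up by tool name
# and handed to parse_xml_params as json_schema=...
print(json_schemas.get("get_current_weather", None))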

Bedrock completion handler:

@@ -691,6 +691,7 @@ def completion(
 ):
     exception_mapping_worked = False
     _is_function_call = False
+    json_schemas: dict = {}
     try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)

@@ -757,6 +758,10 @@ def completion(
             ## Handle Tool Calling
             if "tools" in inference_params:
                 _is_function_call = True
+                for tool in inference_params["tools"]:
+                    json_schemas[tool["function"]["name"]] = tool["function"].get(
+                        "parameters", None
+                    )
                 tool_calling_system_prompt = construct_tool_use_system_prompt(
                     tools=inference_params["tools"]
                 )

@@ -943,7 +948,12 @@ def completion(
                 function_arguments_str = (
                     f"<invoke>{function_arguments_str}</invoke>"
                 )
-                function_arguments = parse_xml_params(function_arguments_str)
+                function_arguments = parse_xml_params(
+                    function_arguments_str,
+                    json_schema=json_schemas.get(
+                        function_name, None
+                    ),  # check if we have a json schema for this function name)
+                )
                 _message = litellm.Message(
                     tool_calls=[
                         {

@@ -958,6 +968,9 @@ def completion(
                     content=None,
                 )
                 model_response.choices[0].message = _message  # type: ignore
+                model_response._hidden_params["original_response"] = (
+                    outputText  # allow user to access raw anthropic tool calling response
+                )
             if _is_function_call == True and stream is not None and stream == True:
                 print_verbose(
                     f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"

litellm/llms/prompt_templates/factory.py:

@@ -731,18 +731,53 @@ def contains_tag(tag: str, string: str) -> bool:
     return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))


-def parse_xml_params(xml_content):
+def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
+    """
+    Compare the xml output to the json schema
+
+    check if a value is a list - if so, get it's child elements
+    """
     root = ET.fromstring(xml_content)
     params = {}
-    for child in root.findall(".//parameters/*"):
-        try:
-            # Attempt to decode the element's text as JSON
-            params[child.tag] = json.loads(child.text)
-        except json.JSONDecodeError:
-            # If JSON decoding fails, use the original text
-            params[child.tag] = child.text
+
+    if json_schema is not None:  # check if we have a json schema for this function call
+        # iterate over all properties in the schema
+        for prop in json_schema["properties"]:
+            # If property is an array, get the nested items
+            _element = root.find(f"parameters/{prop}")
+            if json_schema["properties"][prop]["type"] == "array":
+                items = []
+                if _element is not None:
+                    for value in _element:
+                        try:
+                            if value.text is not None:
+                                _value = json.loads(value.text)
+                            else:
+                                continue
+                        except json.JSONDecodeError:
+                            _value = value.text
+                        items.append(_value)
+                params[prop] = items
+            # If property is not an array, append the value directly
+            elif _element is not None and _element.text is not None:
+                try:
+                    _value = json.loads(_element.text)
+                except json.JSONDecodeError:
+                    _value = _element.text
+                params[prop] = _value
+    else:
+        for child in root.findall(".//parameters/*"):
+            if child is not None and child.text is not None:
+                try:
+                    # Attempt to decode the element's text as JSON
+                    params[child.tag] = json.loads(child.text)  # type: ignore
+                except json.JSONDecodeError:
+                    # If JSON decoding fails, use the original text
+                    params[child.tag] = child.text  # type: ignore

     return params


 ###
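
Why the schema matters, sketched against the else branch above (the XML string here is an illustrative example of Claude's XML tool-calling format, not taken from the commit): without a json_schema, the generic child walk only sees the direct children of <parameters>, so the nested <item> elements of a list-valued argument are lost; with the schema, the "array" type tells the parser to collect them.

from litellm.llms.prompt_templates.factory import parse_xml_params

xml_content = (
    "<invoke><tool_name>return_list_of_str</tool_name>"
    "<parameters><value><item>apple</item><item>banana</item></value></parameters></invoke>"
)

# Fallback path (no schema): <value> has no direct text of its own, so nothing usable
# is recorded for it and the nested items are dropped.
print(parse_xml_params(xml_content))

# Schema-aware path: the "array" type triggers iteration over <value>'s child elements.
schema = {
    "type": "object",
    "properties": {"value": {"type": "array", "items": {"type": "string"}}},
    "required": ["value"],
}
print(parse_xml_params(xml_content, json_schema=schema))  # expected: {'value': ['apple', 'banana']}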

Tool-calling tests:

@@ -195,6 +195,48 @@ def test_completion_claude_3_function_call():
         pytest.fail(f"Error occurred: {e}")


+def test_parse_xml_params():
+    from litellm.llms.prompt_templates.factory import parse_xml_params
+
+    ## SCENARIO 1 ## - W/ ARRAY
+    xml_content = """<invoke><tool_name>return_list_of_str</tool_name>\n<parameters>\n<value>\n<item>apple</item>\n<item>banana</item>\n<item>orange</item>\n</value>\n</parameters></invoke>"""
+    json_schema = {
+        "properties": {
+            "value": {
+                "items": {"type": "string"},
+                "title": "Value",
+                "type": "array",
+            }
+        },
+        "required": ["value"],
+        "type": "object",
+    }
+    response = parse_xml_params(xml_content=xml_content, json_schema=json_schema)
+
+    print(f"response: {response}")
+    assert response["value"] == ["apple", "banana", "orange"]
+
+    ## SCENARIO 2 ## - W/OUT ARRAY
+    xml_content = """<invoke><tool_name>get_current_weather</tool_name>\n<parameters>\n<location>Boston, MA</location>\n<unit>fahrenheit</unit>\n</parameters></invoke>"""
+    json_schema = {
+        "type": "object",
+        "properties": {
+            "location": {
+                "type": "string",
+                "description": "The city and state, e.g. San Francisco, CA",
+            },
+            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+        },
+        "required": ["location"],
+    }
+
+    response = parse_xml_params(xml_content=xml_content, json_schema=json_schema)
+
+    print(f"response: {response}")
+    assert response["location"] == "Boston, MA"
+    assert response["unit"] == "fahrenheit"
+
+
 def test_completion_claude_3_multi_turn_conversations():
     litellm.set_verbose = True
     litellm.modify_params = True