forked from phoenix/litellm-mirror

fix(test_completion.py): testing for anthropic function calling

parent ae82b3f31a
commit c53563a1fe

3 changed files with 92 additions and 6 deletions
litellm/llms/anthropic.py

@@ -2,14 +2,16 @@ import os, types
 import json
 from enum import Enum
 import requests
-import time
+import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, map_finish_reason
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
     construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
 )
 import httpx
@@ -114,6 +116,7 @@ def completion(
     headers={},
 ):
     headers = validate_environment(api_key, headers)
+    _is_function_call = False
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -148,12 +151,14 @@ def completion(
 
     ## Handle Tool Calling
     if "tools" in optional_params:
+        _is_function_call = True
         tool_calling_system_prompt = construct_tool_use_system_prompt(
             tools=optional_params["tools"]
         )
         optional_params["system"] = (
-            optional_params("system", "\n") + tool_calling_system_prompt
+            optional_params.get("system", "\n") + tool_calling_system_prompt
         ) # add the anthropic tool calling prompt to the system prompt
+        optional_params.pop("tools")
 
     data = {
         "model": model,
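Note on the fix in the hunk above: the old line invoked the dict itself, optional_params("system", "\n"), which fails at runtime; .get() is what was intended. A minimal standalone repro; the parameter values here are hypothetical, not part of the diff:

# optional_params as a caller might pass it, with no "system" key set:
optional_params = {"temperature": 0.7}

# Old code called the dict: optional_params("system", "\n")
#   -> TypeError: 'dict' object is not callable

# Fixed code falls back to "\n" when no system prompt was supplied:
system_prompt = optional_params.get("system", "\n") + "<tool-use system prompt>"
print(repr(system_prompt))  # '\n<tool-use system prompt>'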
@@ -221,8 +226,33 @@ def completion(
             )
         else:
             text_content = completion_response["content"][0].get("text", None)
-            model_response.choices[0].message.content = text_content # type: ignore
-            model_response.choices[0].finish_reason = completion_response["stop_reason"]
+            ## TOOL CALLING - OUTPUT PARSE
+            if _is_function_call == True:
+                function_name = extract_between_tags("tool_name", text_content)[0]
+                function_arguments_str = extract_between_tags("invoke", text_content)[
+                    0
+                ].strip()
+                function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
+                function_arguments = parse_xml_params(function_arguments_str)
+                _message = litellm.Message(
+                    tool_calls=[
+                        {
+                            "id": f"call_{uuid.uuid4()}",
+                            "type": "function",
+                            "function": {
+                                "name": function_name,
+                                "arguments": json.dumps(function_arguments),
+                            },
+                        }
+                    ],
+                    content=None,
+                )
+                model_response.choices[0].message = _message
+            else:
+                model_response.choices[0].message.content = text_content # type: ignore
+            model_response.choices[0].finish_reason = map_finish_reason(
+                completion_response["stop_reason"]
+            )
 
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
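For context, the output-parse branch above turns Claude's XML tool-use text back into an OpenAI-style tool call. A standalone trace of that path; the XML shape is an assumption based on Anthropic's documented <function_calls> format, and the values are illustrative:

import json, re, uuid
import xml.etree.ElementTree as ET

text_content = """<function_calls>
<invoke>
<tool_name>get_current_weather</tool_name>
<parameters>
<location>Boston, MA</location>
</parameters>
</invoke>
</function_calls>"""

# Same extraction the diff performs via extract_between_tags/parse_xml_params:
function_name = re.findall("<tool_name>(.+?)</tool_name>", text_content, re.DOTALL)[0]
inner = re.findall("<invoke>(.+?)</invoke>", text_content, re.DOTALL)[0].strip()
root = ET.fromstring(f"<invoke>{inner}</invoke>")
function_arguments = {c.tag: c.text for c in root.findall(".//parameters/*")}

# ...and the OpenAI-style tool_call dict the diff hands to litellm.Message:
tool_call = {
    "id": f"call_{uuid.uuid4()}",
    "type": "function",
    "function": {"name": function_name, "arguments": json.dumps(function_arguments)},
}
print(tool_call["function"])
# {'name': 'get_current_weather', 'arguments': '{"location": "Boston, MA"}'}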
litellm/llms/prompt_templates/factory.py

@@ -1,6 +1,6 @@
 from enum import Enum
 import requests, traceback
-import json
+import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, Environment, meta
 from typing import Optional, Any
@@ -520,6 +520,21 @@ def anthropic_messages_pt(messages: list):
     return new_messages
 
 
+def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
+    ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
+    if strip:
+        ext_list = [e.strip() for e in ext_list]
+    return ext_list
+
+
+def parse_xml_params(xml_content):
+    root = ET.fromstring(xml_content)
+    params = {}
+    for child in root.findall(".//parameters/*"):
+        params[child.tag] = child.text
+    return params
+
+
 ###
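Two behavioral notes on the helpers added above, as a quick standalone check (values illustrative):

import xml.etree.ElementTree as ET

def parse_xml_params(xml_content):  # copied from the diff above
    root = ET.fromstring(xml_content)
    params = {}
    for child in root.findall(".//parameters/*"):
        params[child.tag] = child.text
    return params

print(parse_xml_params("<invoke><parameters><unit>celsius</unit></parameters></invoke>"))
# -> {'unit': 'celsius'}   (values come back as plain strings; no type coercion)

# Caveat: ET.fromstring requires well-formed XML, so a model emitting a raw "&"
# or "<" inside a parameter value raises xml.etree.ElementTree.ParseError.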
litellm/tests/test_completion.py

@@ -99,6 +99,47 @@ def test_completion_claude_3():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_claude_3_function_call():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="anthropic/claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+        )
+        # Add any assertions, here to check response args
+        print(response)
+        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
+        assert isinstance(
+            response.choices[0].message.tool_calls[0].function.arguments, str
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_claude_3_stream():
     litellm.set_verbose = False
     messages = [{"role": "user", "content": "Hello, world"}]
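The new test makes a live API call, so it assumes ANTHROPIC_API_KEY is set like the other tests in this file; it can be run on its own with, e.g.:

#   pytest litellm/tests/test_completion.py::test_completion_claude_3_function_call -s

Since function.arguments is serialized with json.dumps, a follow-up assertion that could be added inside the try block (hypothetical, not in the commit) is to round-trip it:

import json
args = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
assert "location" in args  # arguments deserialize back to the parsed XML params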