forked from phoenix/litellm-mirror
fix(test_completion.py): testing for anthropic function calling
This commit is contained in:
parent
ae82b3f31a
commit
c53563a1fe
3 changed files with 92 additions and 6 deletions
|
@ -2,14 +2,16 @@ import os, types
|
|||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import time
|
||||
import time, uuid
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason
|
||||
import litellm
|
||||
from .prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
construct_tool_use_system_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
)
|
||||
import httpx
|
||||
|
||||
|
@ -114,6 +116,7 @@ def completion(
|
|||
headers={},
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
_is_function_call = False
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
|
@ -148,12 +151,14 @@ def completion(
|
|||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||
tools=optional_params["tools"]
|
||||
)
|
||||
optional_params["system"] = (
|
||||
optional_params("system", "\n") + tool_calling_system_prompt
|
||||
optional_params.get("system", "\n") + tool_calling_system_prompt
|
||||
) # add the anthropic tool calling prompt to the system prompt
|
||||
optional_params.pop("tools")
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
|
@ -221,8 +226,33 @@ def completion(
|
|||
)
|
||||
else:
|
||||
text_content = completion_response["content"][0].get("text", None)
|
||||
## TOOL CALLING - OUTPUT PARSE
|
||||
if _is_function_call == True:
|
||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||
function_arguments_str = extract_between_tags("invoke", text_content)[
|
||||
0
|
||||
].strip()
|
||||
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
|
||||
function_arguments = parse_xml_params(function_arguments_str)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message
|
||||
else:
|
||||
model_response.choices[0].message.content = text_content # type: ignore
|
||||
model_response.choices[0].finish_reason = completion_response["stop_reason"]
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum
|
||||
import requests, traceback
|
||||
import json
|
||||
import json, re, xml.etree.ElementTree as ET
|
||||
from jinja2 import Template, exceptions, Environment, meta
|
||||
from typing import Optional, Any
|
||||
|
||||
|
@ -520,6 +520,21 @@ def anthropic_messages_pt(messages: list):
|
|||
return new_messages
|
||||
|
||||
|
||||
def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
|
||||
ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
|
||||
if strip:
|
||||
ext_list = [e.strip() for e in ext_list]
|
||||
return ext_list
|
||||
|
||||
|
||||
def parse_xml_params(xml_content):
|
||||
root = ET.fromstring(xml_content)
|
||||
params = {}
|
||||
for child in root.findall(".//parameters/*"):
|
||||
params[child.tag] = child.text
|
||||
return params
|
||||
|
||||
|
||||
###
|
||||
|
||||
|
||||
|
|
|
@ -99,6 +99,47 @@ def test_completion_claude_3():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_claude_3_function_call():
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_claude_3_stream():
|
||||
litellm.set_verbose = False
|
||||
messages = [{"role": "user", "content": "Hello, world"}]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue