fix(test_completion.py): testing for anthropic function calling

This commit is contained in:
Krrish Dholakia 2024-03-04 11:31:56 -08:00
parent ae82b3f31a
commit c53563a1fe
3 changed files with 92 additions and 6 deletions

View file

@ -2,14 +2,16 @@ import os, types
import json
from enum import Enum
import requests
import time
import time, uuid
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, map_finish_reason
import litellm
from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx
@ -114,6 +116,7 @@ def completion(
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -148,12 +151,14 @@ def completion(
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params("system", "\n") + tool_calling_system_prompt
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
data = {
"model": model,
@ -221,8 +226,33 @@ def completion(
)
else:
text_content = completion_response["content"][0].get("text", None)
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## TOOL CALLING - OUTPUT PARSE
if _is_function_call == True:
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]

View file

@ -1,6 +1,6 @@
from enum import Enum
import requests, traceback
import json
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any
@ -520,6 +520,21 @@ def anthropic_messages_pt(messages: list):
return new_messages
def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
if strip:
ext_list = [e.strip() for e in ext_list]
return ext_list
def parse_xml_params(xml_content):
root = ET.fromstring(xml_content)
params = {}
for child in root.findall(".//parameters/*"):
params[child.tag] = child.text
return params
###

View file

@ -99,6 +99,47 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_stream():
litellm.set_verbose = False
messages = [{"role": "user", "content": "Hello, world"}]