diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index ce413be65..6cea6450b 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -2,14 +2,16 @@ import os, types
 import json
 from enum import Enum
 import requests
-import time
+import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, map_finish_reason
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
     construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
 )
 import httpx
 
@@ -114,6 +116,7 @@ def completion(
     headers={},
 ):
     headers = validate_environment(api_key, headers)
+    _is_function_call = False
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -148,12 +151,14 @@ def completion(
 
     ## Handle Tool Calling
     if "tools" in optional_params:
+        _is_function_call = True
         tool_calling_system_prompt = construct_tool_use_system_prompt(
             tools=optional_params["tools"]
         )
         optional_params["system"] = (
-            optional_params("system", "\n") + tool_calling_system_prompt
+            optional_params.get("system", "\n") + tool_calling_system_prompt
         )  # add the anthropic tool calling prompt to the system prompt
+        optional_params.pop("tools")
 
     data = {
         "model": model,
@@ -221,8 +226,33 @@ def completion(
             )
         else:
             text_content = completion_response["content"][0].get("text", None)
-            model_response.choices[0].message.content = text_content  # type: ignore
-        model_response.choices[0].finish_reason = completion_response["stop_reason"]
+            ## TOOL CALLING - OUTPUT PARSE
+            if _is_function_call == True:
+                function_name = extract_between_tags("tool_name", text_content)[0]
+                function_arguments_str = extract_between_tags("invoke", text_content)[
+                    0
+                ].strip()
+                function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
+                function_arguments = parse_xml_params(function_arguments_str)
+                _message = litellm.Message(
+                    tool_calls=[
+                        {
+                            "id": f"call_{uuid.uuid4()}",
+                            "type": "function",
+                            "function": {
+                                "name": function_name,
+                                "arguments": json.dumps(function_arguments),
+                            },
+                        }
+                    ],
+                    content=None,
+                )
+                model_response.choices[0].message = _message
+            else:
+                model_response.choices[0].message.content = text_content  # type: ignore
+        model_response.choices[0].finish_reason = map_finish_reason(
+            completion_response["stop_reason"]
+        )
 
     ## CALCULATING USAGE
     prompt_tokens = completion_response["usage"]["input_tokens"]
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index cc75237e0..2b0f4a2cf 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -1,6 +1,6 @@
 from enum import Enum
 import requests, traceback
-import json
+import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, Environment, meta
 from typing import Optional, Any
 
@@ -520,6 +520,21 @@ def anthropic_messages_pt(messages: list):
     return new_messages
 
 
+def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
+    ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
+    if strip:
+        ext_list = [e.strip() for e in ext_list]
+    return ext_list
+
+
+def parse_xml_params(xml_content):
+    root = ET.fromstring(xml_content)
+    params = {}
+    for child in root.findall(".//parameters/*"):
+        params[child.tag] = child.text
+    return params
+
+
 ###
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
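Reviewer note: the parsing path added above can be exercised without an API call. Below is a minimal sketch of the round trip — Claude 3 returns tool calls as XML text (per the `construct_tool_use_system_prompt` format), and the new helpers convert that text back into an OpenAI-style `tool_calls` entry. The sample `text_content` is an assumed example of that XML shape, not captured model output, and the helpers are copied locally so the snippet is self-contained:

```python
import json
import re
import uuid
import xml.etree.ElementTree as ET


def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
    # local copy of the helper added to prompt_templates/factory.py above
    ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
    if strip:
        ext_list = [e.strip() for e in ext_list]
    return ext_list


def parse_xml_params(xml_content: str) -> dict:
    # local copy of the helper added to prompt_templates/factory.py above
    root = ET.fromstring(xml_content)
    return {child.tag: child.text for child in root.findall(".//parameters/*")}


# Assumed example of the XML tool-use block Claude 3 emits under the
# tool-calling system prompt (illustrative only).
text_content = """<function_calls>
<invoke>
<tool_name>get_current_weather</tool_name>
<parameters>
<location>Boston, MA</location>
<unit>fahrenheit</unit>
</parameters>
</invoke>
</function_calls>"""

# Same steps as the new _is_function_call branch in anthropic.py's completion():
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
function_arguments = parse_xml_params(f"<invoke>{function_arguments_str}</invoke>")

tool_call = {
    "id": f"call_{uuid.uuid4()}",
    "type": "function",
    "function": {
        "name": function_name,
        "arguments": json.dumps(function_arguments),
    },
}
print(tool_call)
```

Running this prints a dict whose `arguments` field is a JSON string, matching what the diff now stores on `litellm.Message`. Note that `parse_xml_params` returns every parameter as a string, since it reads raw XML text nodes.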
index a9d41be8d..068ddc78f 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -99,6 +99,47 @@ def test_completion_claude_3():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_claude_3_function_call():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="anthropic/claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+        )
+        # Add any assertions here to check response args
+        print(response)
+        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
+        assert isinstance(
+            response.choices[0].message.tool_calls[0].function.arguments, str
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_claude_3_stream():
     litellm.set_verbose = False
     messages = [{"role": "user", "content": "Hello, world"}]
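One more note: the test asserts the shape of `tool_calls` but stops before the dispatch step. A hedged sketch of that consumer side, assuming a `response` produced by a call like the one in the test — `dispatch_tool_call` and the `get_current_weather` body are hypothetical stand-ins, not part of litellm:

```python
import json


def get_current_weather(location: str, unit: str = "fahrenheit") -> str:
    # hypothetical local implementation of the tool the test advertises
    return json.dumps({"location": location, "temperature": "72", "unit": unit})


def dispatch_tool_call(tool_call) -> str:
    # `arguments` arrives as a JSON string (json.dumps(...) in anthropic.py
    # above), so decode it before calling the local function
    args = json.loads(tool_call.function.arguments)
    if tool_call.function.name == "get_current_weather":
        return get_current_weather(**args)
    raise ValueError(f"unknown tool: {tool_call.function.name}")


# usage (after a call like the one in the test above):
#   response = litellm.completion(model="anthropic/claude-3-opus-20240229",
#                                 messages=messages, tools=tools, tool_choice="auto")
#   result = dispatch_tool_call(response.choices[0].message.tool_calls[0])
```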