Merge pull request #2558 from lucasmrdt/main

fix(anthropic): tool calling detection
This commit is contained in:
Krish Dholakia 2024-03-19 11:48:05 -07:00 committed by GitHub
commit 97130bb34b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 4 additions and 1 deletions

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import ( from .prompt_templates.factory import (
contains_tag,
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
construct_tool_use_system_prompt, construct_tool_use_system_prompt,
@ -235,7 +236,7 @@ def completion(
else: else:
text_content = completion_response["content"][0].get("text", None) text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE ## TOOL CALLING - OUTPUT PARSE
if text_content is not None and "invoke" in text_content: if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0] function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[ function_arguments_str = extract_between_tags("invoke", text_content)[
0 0

View file

@ -714,6 +714,8 @@ def extract_between_tags(tag: str, string: str, strip: bool = False) -> List[str
ext_list = [e.strip() for e in ext_list] ext_list = [e.strip() for e in ext_list]
return ext_list return ext_list
def contains_tag(tag: str, string: str) -> bool:
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
def parse_xml_params(xml_content): def parse_xml_params(xml_content):
root = ET.fromstring(xml_content) root = ET.fromstring(xml_content)