forked from phoenix/litellm-mirror
Merge pull request #2558 from lucasmrdt/main
fix(anthropic): tool calling detection
This commit is contained in:
commit
97130bb34b
2 changed files with 4 additions and 1 deletions
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import (
|
||||
contains_tag,
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
construct_tool_use_system_prompt,
|
||||
|
@ -235,7 +236,7 @@ def completion(
|
|||
else:
|
||||
text_content = completion_response["content"][0].get("text", None)
|
||||
## TOOL CALLING - OUTPUT PARSE
|
||||
if text_content is not None and "invoke" in text_content:
|
||||
if text_content is not None and contains_tag("invoke", text_content):
|
||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||
function_arguments_str = extract_between_tags("invoke", text_content)[
|
||||
0
|
||||
|
|
|
@ -714,6 +714,8 @@ def extract_between_tags(tag: str, string: str, strip: bool = False) -> List[str
|
|||
ext_list = [e.strip() for e in ext_list]
|
||||
return ext_list
|
||||
|
||||
def contains_tag(tag: str, string: str) -> bool:
|
||||
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
|
||||
|
||||
def parse_xml_params(xml_content):
|
||||
root = ET.fromstring(xml_content)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue