forked from phoenix/litellm-mirror
Merge pull request #2558 from lucasmrdt/main
fix(anthropic): tool calling detection
This commit is contained in:
commit
97130bb34b
2 changed files with 4 additions and 1 deletions
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import (
|
from .prompt_templates.factory import (
|
||||||
|
contains_tag,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
custom_prompt,
|
custom_prompt,
|
||||||
construct_tool_use_system_prompt,
|
construct_tool_use_system_prompt,
|
||||||
|
@ -235,7 +236,7 @@ def completion(
|
||||||
else:
|
else:
|
||||||
text_content = completion_response["content"][0].get("text", None)
|
text_content = completion_response["content"][0].get("text", None)
|
||||||
## TOOL CALLING - OUTPUT PARSE
|
## TOOL CALLING - OUTPUT PARSE
|
||||||
if text_content is not None and "invoke" in text_content:
|
if text_content is not None and contains_tag("invoke", text_content):
|
||||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||||
function_arguments_str = extract_between_tags("invoke", text_content)[
|
function_arguments_str = extract_between_tags("invoke", text_content)[
|
||||||
0
|
0
|
||||||
|
|
|
@ -714,6 +714,8 @@ def extract_between_tags(tag: str, string: str, strip: bool = False) -> List[str
|
||||||
ext_list = [e.strip() for e in ext_list]
|
ext_list = [e.strip() for e in ext_list]
|
||||||
return ext_list
|
return ext_list
|
||||||
|
|
||||||
|
def contains_tag(tag: str, string: str) -> bool:
|
||||||
|
return bool(re.search(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL))
|
||||||
|
|
||||||
def parse_xml_params(xml_content):
|
def parse_xml_params(xml_content):
|
||||||
root = ET.fromstring(xml_content)
|
root = ET.fromstring(xml_content)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue