forked from phoenix-oss/llama-stack-mirror
feat: support tool_choice = {required, none, <function>} (#1059)
Summary: titled Test Plan: added tests and LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
parent
37cf60b732
commit
8de7cf103b
7 changed files with 164 additions and 41 deletions
|
@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
|
|||
SystemMessage,
|
||||
SystemMessageBehavior,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -311,8 +312,6 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
|
|||
def augment_messages_for_tools_llama_3_1(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
|
@ -352,6 +351,10 @@ def augment_messages_for_tools_llama_3_1(
|
|||
elif isinstance(existing_system_message.content, list):
|
||||
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
|
||||
|
||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
||||
if tool_choice_prompt:
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
messages.append(SystemMessage(content=sys_content))
|
||||
|
||||
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
||||
|
@ -377,8 +380,6 @@ def augment_messages_for_tools_llama_3_1(
|
|||
def augment_messages_for_tools_llama_3_2(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
|
@ -386,7 +387,6 @@ def augment_messages_for_tools_llama_3_2(
|
|||
|
||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
||||
|
||||
messages = []
|
||||
sys_content = ""
|
||||
custom_tools, builtin_tools = [], []
|
||||
for t in request.tools:
|
||||
|
@ -395,7 +395,6 @@ def augment_messages_for_tools_llama_3_2(
|
|||
else:
|
||||
builtin_tools.append(t)
|
||||
|
||||
tool_template = None
|
||||
if builtin_tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(builtin_tools)
|
||||
|
@ -423,8 +422,22 @@ def augment_messages_for_tools_llama_3_2(
|
|||
):
|
||||
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
|
||||
|
||||
messages.append(SystemMessage(content=sys_content.strip("\n")))
|
||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
||||
if tool_choice_prompt:
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
# Add back existing messages from the request
|
||||
messages += existing_messages
|
||||
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
|
||||
return messages
|
||||
|
||||
|
||||
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
|
||||
if tool_choice == ToolChoice.auto:
|
||||
return ""
|
||||
elif tool_choice == ToolChoice.required:
|
||||
return "You MUST use one of the provided functions/tools to answer the user query."
|
||||
elif tool_choice == ToolChoice.none:
|
||||
# tools are already not passed in
|
||||
return ""
|
||||
else:
|
||||
# specific tool
|
||||
return f"You MUST use the tool `{tool_choice}` to answer the user query."
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue