feat: support tool_choice = {required, none, <function>} (#1059)

Summary:

As titled: adds support for `tool_choice` = `required`, `none`, or a specific function name.


Test Plan:

Added tests and ran:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ \
  --safety-shield meta-llama/Llama-Guard-3-8B
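
The new values can be exercised end to end from a client. Below is a minimal sketch, assuming the llama-stack-client Python SDK's `inference.chat_completion` API; the `get_weather` tool definition, base URL, and model id are illustrative and not from this PR:

# Sketch only -- not part of this PR.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

weather_tool = {
    "tool_name": "get_weather",  # hypothetical tool
    "description": "Get the current weather for a city",
    "parameters": {
        "city": {"param_type": "string", "description": "City name", "required": True},
    },
}

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
    tools=[weather_tool],
    # tool_choice also accepts "auto", "required", or "none"
    tool_config={"tool_choice": "get_weather"},
)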
Author: ehhuang (committed by GitHub)
Date: 2025-02-18 20:25:15 -08:00
Commit: 8de7cf103b (parent: 37cf60b732)
7 changed files with 164 additions and 41 deletions


@@ -182,10 +182,12 @@ class ToolChoice(Enum):
     :cvar auto: The model may use tools if it determines that is appropriate.
     :cvar required: The model must use tools.
+    :cvar none: The model must not use tools.
     """

     auto = "auto"
     required = "required"
+    none = "none"


 @json_schema_type
@@ -326,7 +328,7 @@ class SystemMessageBehavior(Enum):
 class ToolConfig(BaseModel):
     """Configuration for tool use.

-    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
     :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
         - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
         - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
@@ -337,9 +339,16 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+
+    def model_post_init(self, __context: Any) -> None:
+        if isinstance(self.tool_choice, str):
+            try:
+                self.tool_choice = ToolChoice[self.tool_choice]
+            except KeyError:
+                pass


 # This is an internally used class
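
The `model_post_init` hook above coerces strings that name a `ToolChoice` member back to the enum and leaves any other string untouched, so it can be treated as a specific tool name downstream. A minimal sketch of that behavior, assuming the `llama_stack.apis.inference` import path:

from llama_stack.apis.inference import ToolChoice, ToolConfig  # assumed import path

cfg = ToolConfig(tool_choice="required")
assert cfg.tool_choice is ToolChoice.required   # enum name -> coerced to the enum

cfg = ToolConfig(tool_choice="get_weather")     # hypothetical tool name
assert cfg.tool_choice == "get_weather"         # KeyError is swallowed; string kept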


@@ -128,7 +128,7 @@ class InferenceRouter(Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -140,20 +140,36 @@ class InferenceRouter(Inference):
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
-            if tool_choice != tool_config.tool_choice:
+            if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format != tool_config.tool_prompt_format:
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
                 raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
-            tool_config = ToolConfig(
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
-            )
+            params = {}
+            if tool_choice:
+                params["tool_choice"] = tool_choice
+            if tool_prompt_format:
+                params["tool_prompt_format"] = tool_prompt_format
+            tool_config = ToolConfig(**params)
+
+        tools = tools or []
+        if tool_config.tool_choice == ToolChoice.none:
+            tools = []
+        elif tool_config.tool_choice == ToolChoice.auto:
+            pass
+        elif tool_config.tool_choice == ToolChoice.required:
+            pass
+        else:
+            # verify tool_choice is one of the tools
+            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+            if tool_config.tool_choice not in tool_names:
+                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+
         params = dict(
             model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools or [],
+            tools=tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             response_format=response_format,
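
Distilled, the new routing behavior is: `none` strips the tools before the provider sees them, `auto` and `required` pass them through, and a specific tool name must match one of the provided tools. A sketch of that logic as a standalone helper (hypothetical function, not the router code itself):

from typing import List

from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolDefinition  # assumed path

def resolve_tools(tool_config: ToolConfig, tools: List[ToolDefinition]) -> List[ToolDefinition]:
    """Hypothetical helper mirroring the tool handling in InferenceRouter.chat_completion."""
    tools = tools or []
    if tool_config.tool_choice == ToolChoice.none:
        return []  # "none": the provider never sees the tools
    if tool_config.tool_choice in (ToolChoice.auto, ToolChoice.required):
        return tools  # passed through; "required" is enforced via the prompt
    # otherwise tool_choice names a specific tool, which must be among the provided tools
    tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
    if tool_config.tool_choice not in tool_names:
        raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
    return tools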


@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
     SystemMessage,
     SystemMessageBehavior,
     ToolChoice,
+    ToolDefinition,
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
@@ -311,8 +312,6 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
-
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -352,6 +351,10 @@ def augment_messages_for_tools_llama_3_1(
     elif isinstance(existing_system_message.content, list):
         sys_content += "\n".join([_process(c) for c in existing_system_message.content])

+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    if tool_choice_prompt:
+        sys_content += "\n" + tool_choice_prompt
+
     messages.append(SystemMessage(content=sys_content))

     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
@@ -377,8 +380,6 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
-
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -386,7 +387,6 @@ def augment_messages_for_tools_llama_3_2(
         assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"

     messages = []
-
     sys_content = ""
     custom_tools, builtin_tools = [], []
     for t in request.tools:
@@ -395,7 +395,6 @@ def augment_messages_for_tools_llama_3_2(
         else:
             builtin_tools.append(t)
-
     tool_template = None
     if builtin_tools:
         tool_gen = BuiltinToolGenerator()
         tool_template = tool_gen.gen(builtin_tools)
@@ -423,8 +422,22 @@ def augment_messages_for_tools_llama_3_2(
     ):
         sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")

-    messages.append(SystemMessage(content=sys_content.strip("\n")))
+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    if tool_choice_prompt:
+        sys_content += "\n" + tool_choice_prompt
+
     # Add back existing messages from the request
-    messages += existing_messages
+    messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]

     return messages
+
+
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+    if tool_choice == ToolChoice.auto:
+        return ""
+    elif tool_choice == ToolChoice.required:
+        return "You MUST use one of the provided functions/tools to answer the user query."
+    elif tool_choice == ToolChoice.none:
+        # tools are already not passed in
+        return ""
+    else:
+        # specific tool
+        return f"You MUST use the tool `{tool_choice}` to answer the user query."