diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 17cf92341..65a1bdd6b 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -2697,7 +2697,8 @@
"type": "string",
"enum": [
"auto",
- "required"
+ "required",
+ "none"
],
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
@@ -3231,13 +3232,22 @@
"type": "object",
"properties": {
"tool_choice": {
- "type": "string",
- "enum": [
- "auto",
- "required"
+ "oneOf": [
+ {
+ "type": "string",
+ "enum": [
+ "auto",
+ "required",
+ "none"
+ ],
+ "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
+ },
+ {
+ "type": "string"
+ }
],
- "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.",
- "default": "auto"
+ "default": "auto",
+ "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
@@ -3259,9 +3269,6 @@
}
},
"additionalProperties": false,
- "required": [
- "system_message_behavior"
- ],
"description": "Configuration for tool use."
},
"ToolDef": {
@@ -4100,7 +4107,8 @@
"type": "string",
"enum": [
"auto",
- "required"
+ "required",
+ "none"
],
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
@@ -4384,7 +4392,8 @@
"type": "string",
"enum": [
"auto",
- "required"
+ "required",
+ "none"
],
"description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead."
},
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index f63374406..60b777e91 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -1637,6 +1637,7 @@ components:
enum:
- auto
- required
+ - none
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
@@ -1994,13 +1995,21 @@ components:
type: object
properties:
tool_choice:
- type: string
- enum:
- - auto
- - required
- description: >-
- (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ oneOf:
+ - type: string
+ enum:
+ - auto
+ - required
+ - none
+ description: >-
+ Whether tool use is required or automatic. This is a hint to the model
+ which may not be followed. It depends on the Instruction Following
+ capabilities of the model.
+ - type: string
default: auto
+ description: >-
+ (Optional) Whether tool use is automatic, required, or none. Can also
+ specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
@@ -2027,8 +2036,6 @@ components:
where the function definitions should be inserted.
default: append
additionalProperties: false
- required:
- - system_message_behavior
description: Configuration for tool use.
ToolDef:
type: object
@@ -2533,6 +2540,7 @@ components:
enum:
- auto
- required
+ - none
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
@@ -2739,6 +2747,7 @@ components:
enum:
- auto
- required
+ - none
description: >-
(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated:: Use tool_config instead.
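
For reference, a minimal sketch (not part of this patch) of the three shapes the updated tool_config.tool_choice schema accepts; get_weather is a hypothetical tool name used only for illustration.

# Sketch only: tool_config payloads permitted by the updated schema above.
tool_config_auto = {"tool_choice": "auto"}          # model decides whether to call a tool (default)
tool_config_required = {"tool_choice": "required"}  # model should call one of the provided tools
tool_config_none = {"tool_choice": "none"}          # model must not call tools
tool_config_named = {"tool_choice": "get_weather"}  # hypothetical tool name; must match a provided tool
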
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 433ba3274..a3fb69477 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -182,10 +182,12 @@ class ToolChoice(Enum):
:cvar auto: The model may use tools if it determines that is appropriate.
:cvar required: The model must use tools.
+ :cvar none: The model must not use tools.
"""
auto = "auto"
required = "required"
+ none = "none"
@json_schema_type
@@ -326,7 +328,7 @@ class SystemMessageBehavior(Enum):
class ToolConfig(BaseModel):
"""Configuration for tool use.
- :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag.
@@ -337,9 +339,16 @@ class ToolConfig(BaseModel):
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
- tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+ tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
- system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+ system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+
+ def model_post_init(self, __context: Any) -> None:
+ if isinstance(self.tool_choice, str):
+ try:
+ self.tool_choice = ToolChoice[self.tool_choice]
+ except KeyError:
+ pass
# This is an internally used class
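
A minimal usage sketch (not part of this patch) of the updated ToolConfig, assuming ToolConfig and ToolChoice are re-exported from llama_stack.apis.inference: a string that matches a ToolChoice member name is coerced to the enum in model_post_init, while any other string is kept as-is and interpreted as a specific tool name.

from llama_stack.apis.inference import ToolChoice, ToolConfig

cfg = ToolConfig(tool_choice="required")
assert cfg.tool_choice is ToolChoice.required  # coerced via ToolChoice["required"]

cfg = ToolConfig(tool_choice="get_weather")    # hypothetical tool name
assert cfg.tool_choice == "get_weather"        # unrecognized names stay as plain strings

cfg = ToolConfig()
assert cfg.tool_choice is ToolChoice.auto      # default is unchanged
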
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index f45975189..9d12c8a40 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -128,7 +128,7 @@ class InferenceRouter(Inference):
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
- tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+ tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@@ -140,20 +140,36 @@ class InferenceRouter(Inference):
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
if tool_config:
- if tool_choice != tool_config.tool_choice:
+ if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
- if tool_prompt_format != tool_config.tool_prompt_format:
+ if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
else:
- tool_config = ToolConfig(
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
- )
+ params = {}
+ if tool_choice:
+ params["tool_choice"] = tool_choice
+ if tool_prompt_format:
+ params["tool_prompt_format"] = tool_prompt_format
+ tool_config = ToolConfig(**params)
+
+ tools = tools or []
+ if tool_config.tool_choice == ToolChoice.none:
+ tools = []
+ elif tool_config.tool_choice == ToolChoice.auto:
+ pass
+ elif tool_config.tool_choice == ToolChoice.required:
+ pass
+ else:
+ # verify tool_choice is one of the tools
+ tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+ if tool_config.tool_choice not in tool_names:
+ raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
- tools=tools or [],
+ tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
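
A standalone sketch (not part of this patch) that mirrors the dispatch rule the router now applies; it is not the router itself, and get_weather is a hypothetical tool name.

from typing import List, Union

from llama_stack.apis.inference import ToolChoice, ToolDefinition

def resolve_tools(tool_choice: Union[ToolChoice, str], tools: List[ToolDefinition]) -> List[ToolDefinition]:
    """Mirror of the router logic above: 'none' clears the tool list, 'auto'/'required'
    pass it through, and a bare string must name one of the provided tools."""
    if tool_choice == ToolChoice.none:
        return []
    if tool_choice in (ToolChoice.auto, ToolChoice.required):
        return tools
    tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
    if tool_choice not in tool_names:
        raise ValueError(f"Tool choice {tool_choice} is not one of the tools: {tool_names}")
    return tools
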
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index b7945dee7..2782c661f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
SystemMessage,
SystemMessageBehavior,
ToolChoice,
+ ToolDefinition,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
@@ -311,8 +312,6 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
- assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
-
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
@@ -352,6 +351,10 @@ def augment_messages_for_tools_llama_3_1(
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+ tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+ if tool_choice_prompt:
+ sys_content += "\n" + tool_choice_prompt
+
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
@@ -377,8 +380,6 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
- assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
-
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
@@ -386,7 +387,6 @@ def augment_messages_for_tools_llama_3_2(
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
- messages = []
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
@@ -395,7 +395,6 @@ def augment_messages_for_tools_llama_3_2(
else:
builtin_tools.append(t)
- tool_template = None
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
@@ -423,8 +422,22 @@ def augment_messages_for_tools_llama_3_2(
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
- messages.append(SystemMessage(content=sys_content.strip("\n")))
+ tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+ if tool_choice_prompt:
+ sys_content += "\n" + tool_choice_prompt
- # Add back existing messages from the request
- messages += existing_messages
+ messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
+
+
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+ if tool_choice == ToolChoice.auto:
+ return ""
+ elif tool_choice == ToolChoice.required:
+ return "You MUST use one of the provided functions/tools to answer the user query."
+ elif tool_choice == ToolChoice.none:
+        # the router has already dropped the tools when tool_choice is none, so no extra prompt is needed
+ return ""
+ else:
+ # specific tool
+ return f"You MUST use the tool `{tool_choice}` to answer the user query."
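
A minimal sketch (not part of this patch) of the system-prompt suffix _get_tool_choice_prompt produces for each case; get_weather is again a hypothetical tool name, and the empty tool lists are placeholders since the argument is not inspected.

from llama_stack.apis.inference import ToolChoice
from llama_stack.providers.utils.inference.prompt_adapter import _get_tool_choice_prompt

assert _get_tool_choice_prompt(ToolChoice.auto, []) == ""
assert _get_tool_choice_prompt(ToolChoice.none, []) == ""  # router already dropped the tools
assert (
    _get_tool_choice_prompt(ToolChoice.required, [])
    == "You MUST use one of the provided functions/tools to answer the user query."
)
assert (
    _get_tool_choice_prompt("get_weather", [])
    == "You MUST use the tool `get_weather` to answer the user query."
)
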
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 0369f325b..e5380d357 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -98,7 +98,6 @@ def agent_config(llama_stack_client, text_model_id):
},
},
toolgroups=[],
- tool_choice="auto",
input_shields=available_shields,
output_shields=available_shields,
enable_session_persistence=False,
@@ -322,6 +321,38 @@ def test_custom_tool(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str
+def test_tool_choice(llama_stack_client, agent_config):
+ data = [
+ ("required", '{"type": "function"'),
+ ("none", None),
+ ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
+ ]
+ client_tool = TestClientTool()
+ for tool_choice, expected_tool in data:
+ agent_config["tool_config"] = {"tool_choice": tool_choice}
+ agent_config["client_tools"] = [client_tool.get_tool_definition()]
+
+ agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+ session_id = agent.create_session(f"test-session-{uuid4()}")
+
+ response = agent.create_turn(
+ messages=[
+ {
+ "role": "user",
+ "content": "What is the boiling point of polyjuice?",
+ },
+ ],
+ session_id=session_id,
+ )
+
+ logs = [str(log) for log in EventLogger().log(response) if log is not None]
+ logs_str = "".join(logs)
+ if expected_tool:
+ assert expected_tool in logs_str
+ else:
+ assert '{"type": "function"' not in logs_str
+
+
# TODO: fix this flaky test
def xtest_override_system_message_behavior(llama_stack_client, agent_config):
client_tool = TestClientTool()
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index c931ca255..52d5a24f2 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
+def test_text_chat_completion_with_tool_choice_required(
+ llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
+):
+ if inference_provider_type == "remote::vllm":
+ pytest.xfail("vllm-project/vllm#13002")
+ response = llama_stack_client.inference.chat_completion(
+ model_id=text_model_id,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What's the weather like in San Francisco?"},
+ ],
+ tools=[get_weather_tool_definition],
+ tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format},
+ stream=True,
+ )
+ tool_invocation_content = extract_tool_invocation_content(response)
+ assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
+
+
+def test_text_chat_completion_with_tool_choice_none(
+ llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+):
+ response = llama_stack_client.inference.chat_completion(
+ model_id=text_model_id,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What's the weather like in San Francisco?"},
+ ],
+ tools=[get_weather_tool_definition],
+ tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+ stream=True,
+ )
+ tool_invocation_content = extract_tool_invocation_content(response)
+ assert tool_invocation_content == ""
+
+
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
class AnswerFormat(BaseModel):
first_name: str