diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 17cf92341..65a1bdd6b 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2697,7 +2697,8 @@ "type": "string", "enum": [ "auto", - "required" + "required", + "none" ], "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." }, @@ -3231,13 +3232,22 @@ "type": "object", "properties": { "tool_choice": { - "type": "string", - "enum": [ - "auto", - "required" + "oneOf": [ + { + "type": "string", + "enum": [ + "auto", + "required", + "none" + ], + "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." + }, + { + "type": "string" + } ], - "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.", - "default": "auto" + "default": "auto", + "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto." }, "tool_prompt_format": { "type": "string", @@ -3259,9 +3269,6 @@ } }, "additionalProperties": false, - "required": [ - "system_message_behavior" - ], "description": "Configuration for tool use." }, "ToolDef": { @@ -4100,7 +4107,8 @@ "type": "string", "enum": [ "auto", - "required" + "required", + "none" ], "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." }, @@ -4384,7 +4392,8 @@ "type": "string", "enum": [ "auto", - "required" + "required", + "none" ], "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead." }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index f63374406..60b777e91 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1637,6 +1637,7 @@ components: enum: - auto - required + - none description: >- Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities @@ -1994,13 +1995,21 @@ components: type: object properties: tool_choice: - type: string - enum: - - auto - - required - description: >- - (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + oneOf: + - type: string + enum: + - auto + - required + - none + description: >- + Whether tool use is required or automatic. This is a hint to the model + which may not be followed. It depends on the Instruction Following + capabilities of the model. + - type: string default: auto + description: >- + (Optional) Whether tool use is automatic, required, or none. Can also + specify a tool name to use a specific tool. Defaults to ToolChoice.auto. tool_prompt_format: type: string enum: @@ -2027,8 +2036,6 @@ components: where the function definitions should be inserted. default: append additionalProperties: false - required: - - system_message_behavior description: Configuration for tool use. ToolDef: type: object @@ -2533,6 +2540,7 @@ components: enum: - auto - required + - none description: >- Whether tool use is required or automatic. This is a hint to the model which may not be followed. 
It depends on the Instruction Following capabilities @@ -2739,6 +2747,7 @@ components: enum: - auto - required + - none description: >- (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead. diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 433ba3274..a3fb69477 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -182,10 +182,12 @@ class ToolChoice(Enum): :cvar auto: The model may use tools if it determines that is appropriate. :cvar required: The model must use tools. + :cvar none: The model must not use tools. """ auto = "auto" required = "required" + none = "none" @json_schema_type @@ -326,7 +328,7 @@ class SystemMessageBehavior(Enum): class ToolConfig(BaseModel): """Configuration for tool use. - :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. @@ -337,9 +339,16 @@ class ToolConfig(BaseModel): '{{function_definitions}}' to indicate where the function definitions should be inserted. """ - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) - system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append) + system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append) + + def model_post_init(self, __context: Any) -> None: + if isinstance(self.tool_choice, str): + try: + self.tool_choice = ToolChoice[self.tool_choice] + except KeyError: + pass # This is an internally used class diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index f45975189..9d12c8a40 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -128,7 +128,7 @@ class InferenceRouter(Inference): sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, @@ -140,20 +140,36 @@ class InferenceRouter(Inference): if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") if tool_config: - if tool_choice != tool_config.tool_choice: + if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format != tool_config.tool_prompt_format: + if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") else: - tool_config = 
ToolConfig( - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - ) + params = {} + if tool_choice: + params["tool_choice"] = tool_choice + if tool_prompt_format: + params["tool_prompt_format"] = tool_prompt_format + tool_config = ToolConfig(**params) + + tools = tools or [] + if tool_config.tool_choice == ToolChoice.none: + tools = [] + elif tool_config.tool_choice == ToolChoice.auto: + pass + elif tool_config.tool_choice == ToolChoice.required: + pass + else: + # verify tool_choice is one of the tools + tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + if tool_config.tool_choice not in tool_names: + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + params = dict( model_id=model_id, messages=messages, sampling_params=sampling_params, - tools=tools or [], + tools=tools, tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, response_format=response_format, diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index b7945dee7..2782c661f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -31,6 +31,7 @@ from llama_stack.apis.inference import ( SystemMessage, SystemMessageBehavior, ToolChoice, + ToolDefinition, UserMessage, ) from llama_stack.models.llama.datatypes import ( @@ -311,8 +312,6 @@ def response_format_prompt(fmt: Optional[ResponseFormat]): def augment_messages_for_tools_llama_3_1( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" - existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -352,6 +351,10 @@ def augment_messages_for_tools_llama_3_1( elif isinstance(existing_system_message.content, list): sys_content += "\n".join([_process(c) for c in existing_system_message.content]) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + if tool_choice_prompt: + sys_content += "\n" + tool_choice_prompt + messages.append(SystemMessage(content=sys_content)) has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) @@ -377,8 +380,6 @@ def augment_messages_for_tools_llama_3_1( def augment_messages_for_tools_llama_3_2( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" - existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -386,7 +387,6 @@ def augment_messages_for_tools_llama_3_2( assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - messages = [] sys_content = "" custom_tools, builtin_tools = [], [] for t in request.tools: @@ -395,7 +395,6 @@ def augment_messages_for_tools_llama_3_2( else: builtin_tools.append(t) - tool_template = None if builtin_tools: tool_gen = BuiltinToolGenerator() tool_template = tool_gen.gen(builtin_tools) @@ -423,8 +422,22 @@ def augment_messages_for_tools_llama_3_2( ): sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - messages.append(SystemMessage(content=sys_content.strip("\n"))) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + if tool_choice_prompt: + sys_content += "\n" + 
tool_choice_prompt - # Add back existing messages from the request - messages += existing_messages + messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] return messages + + +def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str: + if tool_choice == ToolChoice.auto: + return "" + elif tool_choice == ToolChoice.required: + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none: + # tools are already not passed in + return "" + else: + # specific tool + return f"You MUST use the tool `{tool_choice}` to answer the user query." diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 0369f325b..e5380d357 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -98,7 +98,6 @@ def agent_config(llama_stack_client, text_model_id): }, }, toolgroups=[], - tool_choice="auto", input_shields=available_shields, output_shields=available_shields, enable_session_persistence=False, @@ -322,6 +321,38 @@ def test_custom_tool(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str +def test_tool_choice(llama_stack_client, agent_config): + data = [ + ("required", '{"type": "function"'), + ("none", None), + ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'), + ] + client_tool = TestClientTool() + for tool_choice, expected_tool in data: + agent_config["tool_config"] = {"tool_choice": tool_choice} + agent_config["client_tools"] = [client_tool.get_tool_definition()] + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + if expected_tool: + assert expected_tool in logs_str + else: + assert '{"type": "function"' not in logs_str + + # TODO: fix this flaky test def xtest_override_system_message_behavior(llama_stack_client, agent_config): client_tool = TestClientTool() diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index c931ca255..52d5a24f2 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming( assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" +def test_text_chat_completion_with_tool_choice_required( + llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type +): + if inference_provider_type == "remote::vllm": + pytest.xfail("vllm-project/vllm#13002") + response = llama_stack_client.inference.chat_completion( + model_id=text_model_id, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"}, + ], + tools=[get_weather_tool_definition], + tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format}, + stream=True, + ) + tool_invocation_content = extract_tool_invocation_content(response) + assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" + + +def 
test_text_chat_completion_with_tool_choice_none( + llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format +): + response = llama_stack_client.inference.chat_completion( + model_id=text_model_id, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"}, + ], + tools=[get_weather_tool_definition], + tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format}, + stream=True, + ) + tool_invocation_content = extract_tool_invocation_content(response) + assert tool_invocation_content == "" + + def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type): class AnswerFormat(BaseModel): first_name: str
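
Usage sketch (not part of the patch): the snippet below exercises the three tool_choice behaviours introduced above through the client SDK, mirroring the new tests. It assumes a running Llama Stack server and a registered text model; the base URL, MODEL_ID, and the get_weather tool definition are placeholders, not values defined by this change.

# Sketch only: exercises tool_choice = "required", "none", and a specific tool name.
# Assumes the llama_stack_client SDK and a reachable stack; identifiers below are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder base URL

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder; any registered text model

get_weather_tool = {  # placeholder tool definition, same shape as the test fixture
    "tool_name": "get_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "location": {
            "param_type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        }
    },
}

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather like in San Francisco?"},
]

# "required": the system prompt tells the model it MUST call one of the provided tools.
# "none": the router drops the tools before forwarding the request, so no tool call is made.
# "get_weather": the router verifies the name is among the provided tools, then the prompt
# instructs the model to use that specific tool.
for choice in ("required", "none", "get_weather"):
    response = client.inference.chat_completion(
        model_id=MODEL_ID,
        messages=messages,
        tools=[get_weather_tool],
        tool_config={"tool_choice": choice},
    )
    print(choice, response.completion_message.tool_calls)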