diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 95107d99f..785248633 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel): output_shields: Optional[List[str]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead") + tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead") tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead") tool_config: Optional[ToolConfig] = Field(default=None) @@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel): raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.") if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format: raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.") - if self.tool_config is None: - self.tool_config = ToolConfig( - tool_choice=self.tool_choice, - tool_prompt_format=self.tool_prompt_format, - ) + else: + params = {} + if self.tool_choice: + params["tool_choice"] = self.tool_choice + if self.tool_prompt_format: + params["tool_prompt_format"] = self.tool_prompt_format + self.tool_config = ToolConfig(**params) @json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8f9fa2d82..fe4ccd1a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents): ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) - if agent_config.tool_config is None: - agent_config.tool_config = ToolConfig( - tool_choice=agent_config.tool_choice, - tool_prompt_format=agent_config.tool_prompt_format, - ) - await self.persistence_store.set( key=f"agent:{agent_id}", value=agent_config.model_dump_json(), diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 85b7af831..d14a7003f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -13,11 +13,12 @@ from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage -from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.tool_def_param import Parameter +from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig, ToolChoice class TestClientTool(ClientTool): @@ -141,6 +142,62 @@ def test_agent_simple(llama_stack_client, agent_config): assert "I can't" in logs_str +def test_tool_config(llama_stack_client, agent_config): + common_params = dict( + model="meta-llama/Llama-3.2-3B-Instruct", + instructions="You are a helpful assistant", + sampling_params={ + "strategy": { + "type": "top_p", + "temperature": 1.0, + "top_p": 0.9, + }, + }, + toolgroups=[], + enable_session_persistence=False, + ) + agent_config = AgentConfig( + **common_params, + ) + Server__AgentConfig(**agent_config) + + agent_config = AgentConfig( + **common_params, + tool_choice="auto", + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.auto + + agent_config = AgentConfig( + **common_params, + tool_choice="auto", + tool_config=ToolConfig( + tool_choice="auto", + ), + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.auto + + agent_config = AgentConfig( + **common_params, + tool_config=ToolConfig( + tool_choice="required", + ), + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.required + + agent_config = AgentConfig( + **common_params, + tool_choice="required", + tool_config=ToolConfig( + tool_choice="auto", + ), + ) + with pytest.raises(ValueError, match="tool_choice is deprecated"): + Server__AgentConfig(**agent_config) + + def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config,