diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ae599563d..151ac1451 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3612,6 +3612,9 @@ ], "description": "Prompt format for calling custom / zero shot tools." }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" + }, "max_infer_iters": { "type": "integer", "default": 10 @@ -3881,6 +3884,9 @@ "items": { "$ref": "#/components/schemas/AgentTool" } + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 2953f1b69..37fba4541 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2347,6 +2347,8 @@ components: - python_list description: >- Prompt format for calling custom / zero shot tools. + tool_config: + $ref: '#/components/schemas/ToolConfig' max_infer_iters: type: integer default: 10 @@ -2500,6 +2502,8 @@ components: type: array items: $ref: '#/components/schemas/AgentTool' + tool_config: + $ref: '#/components/schemas/ToolConfig' additionalProperties: false required: - messages diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 38c6b5561..95107d99f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -33,6 +33,7 @@ from llama_stack.apis.inference import ( ToolResponse, ToolResponseMessage, UserMessage, + ToolConfig, ) from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef @@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel): output_shields: Optional[List[str]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) - tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead") + tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead") + tool_config: Optional[ToolConfig] = Field(default=None) max_infer_iters: Optional[int] = 10 + def model_post_init(self, __context): + if self.tool_config: + if self.tool_choice and self.tool_config.tool_choice != self.tool_choice: + raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.") + if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format: + raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.") + if self.tool_config is None: + self.tool_config = ToolConfig( + tool_choice=self.tool_choice, + tool_prompt_format=self.tool_prompt_format, + ) + @json_schema_type class AgentConfig(AgentConfigCommon): @@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): toolgroups: Optional[List[AgentToolGroup]] = None stream: Optional[bool] = False + tool_config: Optional[ToolConfig] = None @json_schema_type @@ -315,6 +330,7 @@ class Agents(Protocol): stream: Optional[bool] = False, documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f5ddbab40..51691c546 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -496,10 +496,11 @@ class ChatAgent(ShieldRunnerMixin): tools=[ tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP ], - tool_prompt_format=self.agent_config.tool_prompt_format, + tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, sampling_params=sampling_params, + tool_config=self.agent_config.tool_config, ): event = chunk.event if event.event_type == ChatCompletionResponseEventType.start: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index b9e3066c6..8f9fa2d82 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -25,7 +25,12 @@ from llama_stack.apis.agents import ( Session, Turn, ) -from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage +from llama_stack.apis.inference import ( + Inference, + ToolConfig, + ToolResponseMessage, + UserMessage, +) from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO @@ -76,6 +81,12 @@ class MetaReferenceAgentsImpl(Agents): ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) + if agent_config.tool_config is None: + agent_config.tool_config = ToolConfig( + tool_choice=agent_config.tool_choice, + tool_prompt_format=agent_config.tool_prompt_format, + ) + await self.persistence_store.set( key=f"agent:{agent_id}", value=agent_config.model_dump_json(), @@ -140,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents): toolgroups: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -148,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents): stream=True, toolgroups=toolgroups, documents=documents, + tool_config=tool_config, ) if stream: return self._create_agent_turn_streaming(request) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index ed599e43d..eb6e68e8f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config): assert "CustomTool" in logs_str +def test_override_system_message_behavior(llama_stack_client, agent_config): + client_tool = TestClientTool() + agent_config = { + **agent_config, + "instructions": "You are a pirate", + "client_tools": [client_tool.get_tool_definition()], + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + # can't tell a joke: "I don't have a function" + assert "function" in logs_str + + # with system message behavior replace + instructions = """ + You are a helpful assistant. You have access to functions, but you should only use them if they are required. + + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function, + also point it out. + + {{ function_description }} + """ + agent_config = { + **agent_config, + "instructions": instructions, + "client_tools": [client_tool.get_tool_definition()], + "tool_config": { + "system_message_behavior": "replace", + }, + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "bicycle" in logs_str + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "-100" in logs_str + assert "CustomTool" in logs_str + + def test_rag_agent(llama_stack_client, agent_config): urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [