diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ae599563d..151ac1451 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -3612,6 +3612,9 @@
],
"description": "Prompt format for calling custom / zero shot tools."
},
+ "tool_config": {
+ "$ref": "#/components/schemas/ToolConfig"
+ },
"max_infer_iters": {
"type": "integer",
"default": 10
@@ -3881,6 +3884,9 @@
"items": {
"$ref": "#/components/schemas/AgentTool"
}
+ },
+ "tool_config": {
+ "$ref": "#/components/schemas/ToolConfig"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 2953f1b69..37fba4541 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2347,6 +2347,8 @@ components:
- python_list
description: >-
Prompt format for calling custom / zero shot tools.
+ tool_config:
+ $ref: '#/components/schemas/ToolConfig'
max_infer_iters:
type: integer
default: 10
@@ -2500,6 +2502,8 @@ components:
type: array
items:
$ref: '#/components/schemas/AgentTool'
+ tool_config:
+ $ref: '#/components/schemas/ToolConfig'
additionalProperties: false
required:
- messages
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 38c6b5561..95107d99f 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
+ ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
- tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
- tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+ tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
+ tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
+ tool_config: Optional[ToolConfig] = Field(default=None)
max_infer_iters: Optional[int] = 10
+
+ def model_post_init(self, __context):
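+ # back-compat: the deprecated top-level tool_choice / tool_prompt_format must
+ # agree with tool_config when both are given, and are folded into a ToolConfig
+ # when tool_config is absent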
+ if self.tool_config:
+ if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
+ raise ValueError("tool_choice is deprecated; use tool_config.tool_choice instead (conflicting values were given)")
+ if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
+ raise ValueError("tool_prompt_format is deprecated; use tool_config.tool_prompt_format instead (conflicting values were given)")
+ if self.tool_config is None:
+ self.tool_config = ToolConfig(
+ tool_choice=self.tool_choice,
+ tool_prompt_format=self.tool_prompt_format,
+ )
+
@json_schema_type
class AgentConfig(AgentConfigCommon):
@@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
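+ # optional per-turn tool configuration (see ToolConfig)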
+ tool_config: Optional[ToolConfig] = None
@json_schema_type
@@ -315,6 +330,7 @@ class Agents(Protocol):
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index f5ddbab40..51691c546 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -496,10 +496,11 @@ class ChatAgent(ShieldRunnerMixin):
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
- tool_prompt_format=self.agent_config.tool_prompt_format,
+ tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
sampling_params=sampling_params,
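+ # tool_config is the consolidated source of tool settings; the standalone
+ # tool_prompt_format above is assumed to be kept for back-compat with
+ # providers that still read the deprecated parameter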
+ tool_config=self.agent_config.tool_config,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index b9e3066c6..8f9fa2d82 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
-from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
+from llama_stack.apis.inference import (
+ Inference,
+ ToolConfig,
+ ToolResponseMessage,
+ UserMessage,
+)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
@@ -76,6 +81,12 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
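+ # configs created before tool_config existed only carry the deprecated
+ # fields; build a ToolConfig from them so downstream code can rely on it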
+ if agent_config.tool_config is None:
+ agent_config.tool_config = ToolConfig(
+ tool_choice=agent_config.tool_choice,
+ tool_prompt_format=agent_config.tool_prompt_format,
+ )
+
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
@@ -140,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@@ -148,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
stream=True,
toolgroups=toolgroups,
documents=documents,
+ tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index ed599e43d..eb6e68e8f 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config):
assert "CustomTool" in logs_str
+def test_override_system_message_behavior(llama_stack_client, agent_config):
+ client_tool = TestClientTool()
+ agent_config = {
+ **agent_config,
+ "instructions": "You are a pirate",
+ "client_tools": [client_tool.get_tool_definition()],
+ }
+
+ agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+ session_id = agent.create_session(f"test-session-{uuid4()}")
+
+ response = agent.create_turn(
+ messages=[
+ {
+ "role": "user",
+ "content": "tell me a joke about bicycles",
+ },
+ ],
+ session_id=session_id,
+ )
+
+ logs = [str(log) for log in EventLogger().log(response) if log is not None]
+ logs_str = "".join(logs)
+ print(logs_str)
+ # with the default system prompt the model can't tell the joke and says
+ # something like "I don't have a function" for it
+ assert "function" in logs_str
+
+ # with system message behavior replace
+ instructions = """
+ You are a helpful assistant. You have access to functions, but you should only use them if they are required.
+
+ You are an expert in composing functions. You are given a question and a set of possible functions.
+ Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose.
+ If none of the functions can be used, don't return []; instead, answer the question directly without using functions.
+ If the given question lacks the parameters required by a function, also point that out.
+
+ {{ function_description }}
+ """
+ agent_config = {
+ **agent_config,
+ "instructions": instructions,
+ "client_tools": [client_tool.get_tool_definition()],
+ "tool_config": {
+ "system_message_behavior": "replace",
+ },
+ }
+
+ agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+ session_id = agent.create_session(f"test-session-{uuid4()}")
+
+ response = agent.create_turn(
+ messages=[
+ {
+ "role": "user",
+ "content": "tell me a joke about bicycles",
+ },
+ ],
+ session_id=session_id,
+ )
+
+ logs = [str(log) for log in EventLogger().log(response) if log is not None]
+ logs_str = "".join(logs)
+ print(logs_str)
+ assert "bicycle" in logs_str
+
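+ # the replaced prompt should not stop the model from calling the client tool
+ # when the question genuinely needs it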
+ response = agent.create_turn(
+ messages=[
+ {
+ "role": "user",
+ "content": "What is the boiling point of polyjuice?",
+ },
+ ],
+ session_id=session_id,
+ )
+
+ logs = [str(log) for log in EventLogger().log(response) if log is not None]
+ logs_str = "".join(logs)
+ print(logs_str)
+ assert "-100" in logs_str
+ assert "CustomTool" in logs_str
+
+
def test_rag_agent(llama_stack_client, agent_config):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [