From 3922999118fc50251a49e6437d39be7eeaed8f48 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 5 Feb 2025 21:11:32 -0800 Subject: [PATCH] sys_prompt support in Agent (#938) # What does this PR do? The current default system prompt for llama3.2 tends to overindex on tool calling and doesn't work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object. - [ ] Addresses issue (#issue) ## Test Plan LLAMA_STACK_CONFIG=together pytest \-\-inference\-model=meta\-llama/Llama\-3\.3\-70B\-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- docs/_static/llama-stack-spec.html | 6 ++ docs/_static/llama-stack-spec.yaml | 4 + llama_stack/apis/agents/agents.py | 20 ++++- .../agents/meta_reference/agent_instance.py | 3 +- .../inline/agents/meta_reference/agents.py | 15 +++- tests/client-sdk/agents/test_agents.py | 82 +++++++++++++++++++ 6 files changed, 126 insertions(+), 4 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ae599563d..151ac1451 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3612,6 +3612,9 @@ ], "description": "Prompt format for calling custom / zero shot tools." }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" + }, "max_infer_iters": { "type": "integer", "default": 10 @@ -3881,6 +3884,9 @@ "items": { "$ref": "#/components/schemas/AgentTool" } + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 2953f1b69..37fba4541 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2347,6 +2347,8 @@ components: - python_list description: >- Prompt format for calling custom / zero shot tools. + tool_config: + $ref: '#/components/schemas/ToolConfig' max_infer_iters: type: integer default: 10 @@ -2500,6 +2502,8 @@ components: type: array items: $ref: '#/components/schemas/AgentTool' + tool_config: + $ref: '#/components/schemas/ToolConfig' additionalProperties: false required: - messages diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 38c6b5561..95107d99f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -33,6 +33,7 @@ from llama_stack.apis.inference import ( ToolResponse, ToolResponseMessage, UserMessage, + ToolConfig, ) from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef @@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel): output_shields: Optional[List[str]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) - tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead") + tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead") + tool_config: Optional[ToolConfig] = Field(default=None) max_infer_iters: Optional[int] = 10 + def model_post_init(self, __context): + if self.tool_config: + if self.tool_choice and self.tool_config.tool_choice != self.tool_choice: + raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.") + if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format: + raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.") + if self.tool_config is None: + self.tool_config = ToolConfig( + tool_choice=self.tool_choice, + tool_prompt_format=self.tool_prompt_format, + ) + @json_schema_type class AgentConfig(AgentConfigCommon): @@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): toolgroups: Optional[List[AgentToolGroup]] = None stream: Optional[bool] = False + tool_config: Optional[ToolConfig] = None @json_schema_type @@ -315,6 +330,7 @@ class Agents(Protocol): stream: Optional[bool] = False, documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f5ddbab40..51691c546 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -496,10 +496,11 @@ class ChatAgent(ShieldRunnerMixin): tools=[ tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP ], - tool_prompt_format=self.agent_config.tool_prompt_format, + tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, sampling_params=sampling_params, + tool_config=self.agent_config.tool_config, ): event = chunk.event if event.event_type == ChatCompletionResponseEventType.start: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index b9e3066c6..8f9fa2d82 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -25,7 +25,12 @@ from llama_stack.apis.agents import ( Session, Turn, ) -from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage +from llama_stack.apis.inference import ( + Inference, + ToolConfig, + ToolResponseMessage, + UserMessage, +) from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO @@ -76,6 +81,12 @@ class MetaReferenceAgentsImpl(Agents): ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) + if agent_config.tool_config is None: + agent_config.tool_config = ToolConfig( + tool_choice=agent_config.tool_choice, + tool_prompt_format=agent_config.tool_prompt_format, + ) + await self.persistence_store.set( key=f"agent:{agent_id}", value=agent_config.model_dump_json(), @@ -140,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents): toolgroups: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -148,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents): stream=True, toolgroups=toolgroups, documents=documents, + tool_config=tool_config, ) if stream: return self._create_agent_turn_streaming(request) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index ed599e43d..eb6e68e8f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config): assert "CustomTool" in logs_str +def test_override_system_message_behavior(llama_stack_client, agent_config): + client_tool = TestClientTool() + agent_config = { + **agent_config, + "instructions": "You are a pirate", + "client_tools": [client_tool.get_tool_definition()], + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + # can't tell a joke: "I don't have a function" + assert "function" in logs_str + + # with system message behavior replace + instructions = """ + You are a helpful assistant. You have access to functions, but you should only use them if they are required. + + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function, + also point it out. + + {{ function_description }} + """ + agent_config = { + **agent_config, + "instructions": instructions, + "client_tools": [client_tool.get_tool_definition()], + "tool_config": { + "system_message_behavior": "replace", + }, + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "bicycle" in logs_str + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "-100" in logs_str + assert "CustomTool" in logs_str + + def test_rag_agent(llama_stack_client, agent_config): urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [