fix: agent config validation (#1053)

Summary:

Fixes AgentConfig init bug introduced with ToolConfig.

Namely, the below doesn't work
```
    agent_config = AgentConfig(
        **common_params,
        tool_config=ToolConfig(
            tool_choice="required",
        ),
    )
```
bvecause tool_choice was defaulted to 'auto' leading to validation check
failing.

Test Plan:

added unittests

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/
--safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
ehhuang 2025-02-11 14:48:42 -08:00 committed by GitHub
parent 6ad272927d
commit 96c88397da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 66 additions and 13 deletions

View file

@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel):
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
if self.tool_config is None:
self.tool_config = ToolConfig(
tool_choice=self.tool_choice,
tool_prompt_format=self.tool_prompt_format,
)
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type

View file

@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
if agent_config.tool_config is None:
agent_config.tool_config = ToolConfig(
tool_choice=agent_config.tool_choice,
tool_prompt_format=agent_config.tool_prompt_format,
)
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),

View file

@ -13,11 +13,12 @@ from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig, ToolChoice
class TestClientTool(ClientTool):
@ -141,6 +142,62 @@ def test_agent_simple(llama_stack_client, agent_config):
assert "I can't" in logs_str
def test_tool_config(llama_stack_client, agent_config):
common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct",
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 1.0,
"top_p": 0.9,
},
},
toolgroups=[],
enable_session_persistence=False,
)
agent_config = AgentConfig(
**common_params,
)
Server__AgentConfig(**agent_config)
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
tool_config=ToolConfig(
tool_choice="auto",
),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_config=ToolConfig(
tool_choice="required",
),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.required
agent_config = AgentConfig(
**common_params,
tool_choice="required",
tool_config=ToolConfig(
tool_choice="auto",
),
)
with pytest.raises(ValueError, match="tool_choice is deprecated"):
Server__AgentConfig(**agent_config)
def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,