mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
fix: agent config validation (#1053)
Summary: Fixes AgentConfig init bug introduced with ToolConfig. Namely, the below doesn't work ``` agent_config = AgentConfig( **common_params, tool_config=ToolConfig( tool_choice="required", ), ) ``` because tool_choice was defaulted to 'auto', causing the validation check to fail. Test Plan: added unittests LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
parent
6ad272927d
commit
96c88397da
3 changed files with 66 additions and 13 deletions
|
@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
|
||||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
|
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
|
||||||
tool_config: Optional[ToolConfig] = Field(default=None)
|
tool_config: Optional[ToolConfig] = Field(default=None)
|
||||||
|
|
||||||
|
@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel):
|
||||||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||||
if self.tool_config is None:
|
else:
|
||||||
self.tool_config = ToolConfig(
|
params = {}
|
||||||
tool_choice=self.tool_choice,
|
if self.tool_choice:
|
||||||
tool_prompt_format=self.tool_prompt_format,
|
params["tool_choice"] = self.tool_choice
|
||||||
)
|
if self.tool_prompt_format:
|
||||||
|
params["tool_prompt_format"] = self.tool_prompt_format
|
||||||
|
self.tool_config = ToolConfig(**params)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
) -> AgentCreateResponse:
|
) -> AgentCreateResponse:
|
||||||
agent_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
if agent_config.tool_config is None:
|
|
||||||
agent_config.tool_config = ToolConfig(
|
|
||||||
tool_choice=agent_config.tool_choice,
|
|
||||||
tool_prompt_format=agent_config.tool_prompt_format,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.persistence_store.set(
|
await self.persistence_store.set(
|
||||||
key=f"agent:{agent_id}",
|
key=f"agent:{agent_id}",
|
||||||
value=agent_config.model_dump_json(),
|
value=agent_config.model_dump_json(),
|
||||||
|
|
|
@ -13,11 +13,12 @@ from llama_stack_client.lib.agents.agent import Agent
|
||||||
from llama_stack_client.lib.agents.client_tool import ClientTool
|
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
from llama_stack_client.types import ToolResponseMessage
|
from llama_stack_client.types import ToolResponseMessage
|
||||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||||
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
||||||
from llama_stack_client.types.memory_insert_params import Document
|
from llama_stack_client.types.memory_insert_params import Document
|
||||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.tool_def_param import Parameter
|
from llama_stack_client.types.tool_def_param import Parameter
|
||||||
|
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig, ToolChoice
|
||||||
|
|
||||||
|
|
||||||
class TestClientTool(ClientTool):
|
class TestClientTool(ClientTool):
|
||||||
|
@ -141,6 +142,62 @@ def test_agent_simple(llama_stack_client, agent_config):
|
||||||
assert "I can't" in logs_str
|
assert "I can't" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_config(llama_stack_client, agent_config):
|
||||||
|
common_params = dict(
|
||||||
|
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||||
|
instructions="You are a helpful assistant",
|
||||||
|
sampling_params={
|
||||||
|
"strategy": {
|
||||||
|
"type": "top_p",
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
toolgroups=[],
|
||||||
|
enable_session_persistence=False,
|
||||||
|
)
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**common_params,
|
||||||
|
)
|
||||||
|
Server__AgentConfig(**agent_config)
|
||||||
|
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**common_params,
|
||||||
|
tool_choice="auto",
|
||||||
|
)
|
||||||
|
server_config = Server__AgentConfig(**agent_config)
|
||||||
|
assert server_config.tool_config.tool_choice == ToolChoice.auto
|
||||||
|
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**common_params,
|
||||||
|
tool_choice="auto",
|
||||||
|
tool_config=ToolConfig(
|
||||||
|
tool_choice="auto",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
server_config = Server__AgentConfig(**agent_config)
|
||||||
|
assert server_config.tool_config.tool_choice == ToolChoice.auto
|
||||||
|
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**common_params,
|
||||||
|
tool_config=ToolConfig(
|
||||||
|
tool_choice="required",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
server_config = Server__AgentConfig(**agent_config)
|
||||||
|
assert server_config.tool_config.tool_choice == ToolChoice.required
|
||||||
|
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**common_params,
|
||||||
|
tool_choice="required",
|
||||||
|
tool_config=ToolConfig(
|
||||||
|
tool_choice="auto",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="tool_choice is deprecated"):
|
||||||
|
Server__AgentConfig(**agent_config)
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue