Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-23 00:27:26 +00:00)
feat(agents)!: changing agents API signatures to use OpenAI types
Replace legacy Message/SamplingParams usage with OpenAI chat message structures across agents: schemas, meta-reference implementation, and tests now rely on OpenAI message/tool payloads and generation knobs.
This commit is contained in:
parent 548ccff368
commit c56b2deb7d

6 changed files with 392 additions and 305 deletions
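The core of the change in miniature: the nested sampling_params dict gives way to flattened OpenAI-style generation knobs on the agent config. A before/after sketch, with field names taken from the hunks below (the model id and instructions are simply the values used in the tests):

# Before: legacy SamplingParams-style nested dict
agent_config = dict(
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    sampling_params={
        "strategy": {"type": "top_p", "temperature": 0.0001, "top_p": 0.9},
        "max_tokens": 512,
    },
)

# After: flattened OpenAI-style generation knobs
agent_config = dict(
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    temperature=0.0001,
    top_p=0.9,
    max_output_tokens=512,
)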
@@ -62,14 +62,9 @@ def agent_config(llama_stack_client, text_model_id):
     agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
-        sampling_params={
-            "strategy": {
-                "type": "top_p",
-                "temperature": 0.0001,
-                "top_p": 0.9,
-            },
-            "max_tokens": 512,
-        },
+        temperature=0.0001,
+        top_p=0.9,
+        max_output_tokens=512,
         tools=[],
         input_shields=available_shields,
         output_shields=available_shields,
@@ -83,14 +78,9 @@ def agent_config_without_safety(text_model_id):
     agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
-        sampling_params={
-            "strategy": {
-                "type": "top_p",
-                "temperature": 0.0001,
-                "top_p": 0.9,
-            },
-            "max_tokens": 512,
-        },
+        temperature=0.0001,
+        top_p=0.9,
+        max_output_tokens=512,
         tools=[],
         enable_session_persistence=False,
     )
@@ -194,14 +184,9 @@ def test_tool_config(agent_config):
     common_params = dict(
         model="meta-llama/Llama-3.2-3B-Instruct",
         instructions="You are a helpful assistant",
-        sampling_params={
-            "strategy": {
-                "type": "top_p",
-                "temperature": 1.0,
-                "top_p": 0.9,
-            },
-            "max_tokens": 512,
-        },
+        temperature=1.0,
+        top_p=0.9,
+        max_output_tokens=512,
         toolgroups=[],
         enable_session_persistence=False,
     )
@@ -212,40 +197,25 @@ def test_tool_config(agent_config):

     agent_config = AgentConfig(
         **common_params,
-        tool_choice="auto",
+        tool_config=ToolConfig(tool_choice="auto"),
     )
     server_config = Server__AgentConfig(**agent_config)
     assert server_config.tool_config.tool_choice == ToolChoice.auto

     agent_config = AgentConfig(
         **common_params,
-        tool_choice="auto",
-        tool_config=ToolConfig(
-            tool_choice="auto",
-        ),
+        tool_config=ToolConfig(tool_choice="auto"),
     )
     server_config = Server__AgentConfig(**agent_config)
     assert server_config.tool_config.tool_choice == ToolChoice.auto

     agent_config = AgentConfig(
         **common_params,
-        tool_config=ToolConfig(
-            tool_choice="required",
-        ),
+        tool_config=ToolConfig(tool_choice="required"),
     )
     server_config = Server__AgentConfig(**agent_config)
     assert server_config.tool_config.tool_choice == ToolChoice.required

-    agent_config = AgentConfig(
-        **common_params,
-        tool_choice="required",
-        tool_config=ToolConfig(
-            tool_choice="auto",
-        ),
-    )
-    with pytest.raises(ValueError, match="tool_choice is deprecated"):
-        Server__AgentConfig(**agent_config)
-

 def test_builtin_tool_web_search(llama_stack_client, agent_config):
     agent_config = {
@@ -7,7 +7,7 @@
 import pytest

 from llama_stack.apis.agents import AgentConfig, Turn
-from llama_stack.apis.inference import SamplingParams, UserMessage
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 @pytest.fixture
 def sample_messages():
     return [
-        UserMessage(content="What's the weather like today?"),
+        OpenAIUserMessageParam(content="What's the weather like today?"),
     ]


@@ -36,7 +36,9 @@ def common_params(inference_model):
         model=inference_model,
         instructions="You are a helpful assistant.",
         enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        temperature=0.7,
+        top_p=0.95,
+        max_output_tokens=256,
         input_shields=[],
         output_shields=[],
         tools=[],
@@ -69,30 +69,26 @@ async def agents_impl(config, mock_apis):
 @pytest.fixture
 def sample_agent_config():
     return AgentConfig(
-        sampling_params={
-            "strategy": {"type": "greedy"},
-            "max_tokens": 0,
-            "repetition_penalty": 1.0,
-        },
+        temperature=0.0,
+        top_p=1.0,
+        max_output_tokens=0,
         input_shields=["string"],
         output_shields=["string"],
         toolgroups=["mcp::my_mcp_server"],
         client_tools=[
             {
+                "type": "function",
                 "name": "client_tool",
                 "description": "Client Tool",
-                "parameters": [
-                    {
-                        "name": "string",
-                        "parameter_type": "string",
-                        "description": "string",
-                        "required": True,
-                        "default": None,
-                    }
-                ],
-                "metadata": {
-                    "property1": None,
-                    "property2": None,
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "string": {
+                            "type": "string",
+                            "description": "string",
+                        }
+                    },
+                    "required": ["string"],
                 },
             }
         ],
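Putting the message and tool changes together: turn input now uses OpenAI-typed message params, and client tools are declared as OpenAI function-tool payloads. A sketch assembled from the hunks above (not a complete AgentConfig; other required fields are omitted):

from llama_stack.apis.inference import OpenAIUserMessageParam

# Messages: OpenAI message params replace the legacy UserMessage type
messages = [OpenAIUserMessageParam(content="What's the weather like today?")]

# Client tools: OpenAI function-tool shape with JSON Schema parameters,
# replacing the legacy parameter-list-plus-metadata format
client_tool = {
    "type": "function",
    "name": "client_tool",
    "description": "Client Tool",
    "parameters": {
        "type": "object",
        "properties": {"string": {"type": "string", "description": "string"}},
        "required": ["string"],
    },
}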