mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:19:49 +00:00
agents to use tools api
This commit is contained in:
parent
596afc6497
commit
f90e9c2003
21 changed files with 538 additions and 329 deletions
|
|
@ -10,14 +10,19 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolGroupInput,
|
||||
ToolParameter,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
||||
|
||||
|
|
@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tool_runtime_memory() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="memory-runtime",
|
||||
provider_type="inline::memory-runtime",
|
||||
config={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||
TOOL_RUNTIME_FIXTURES = ["memory"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety", "memory", "agents"]:
|
||||
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if key == "inference":
|
||||
|
|
@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
metadata={"embedding_dimension": 384},
|
||||
)
|
||||
)
|
||||
tool_groups = [
|
||||
ToolGroupInput(
|
||||
tool_group_id="memory_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
name="memory",
|
||||
description="memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="session_id",
|
||||
description="session id",
|
||||
parameter_type="string",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
description="messages",
|
||||
parameter_type="list",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="attachments",
|
||||
description="attachments",
|
||||
parameter_type="list",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
),
|
||||
provider_id="memory-runtime",
|
||||
)
|
||||
]
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||
providers,
|
||||
provider_data,
|
||||
models=models,
|
||||
shields=[safety_shield] if safety_shield else [],
|
||||
tool_groups=tool_groups,
|
||||
)
|
||||
return test_stack
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue