mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:09:40 +00:00
agents to use tools api
This commit is contained in:
parent
596afc6497
commit
f90e9c2003
21 changed files with 538 additions and 329 deletions
|
|
@ -7,12 +7,10 @@
|
|||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
|
|
@ -21,6 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
|
|
@ -31,6 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
|
|
@ -42,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
|
|
@ -52,6 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory",
|
||||
},
|
||||
id="fireworks",
|
||||
marks=pytest.mark.fireworks,
|
||||
|
|
@ -62,6 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "remote",
|
||||
"memory": "remote",
|
||||
"agents": "remote",
|
||||
"tool_runtime": "memory",
|
||||
},
|
||||
id="remote",
|
||||
marks=pytest.mark.remote,
|
||||
|
|
@ -117,6 +120,7 @@ def pytest_generate_tests(metafunc):
|
|||
"safety": SAFETY_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"agents": AGENTS_FIXTURES,
|
||||
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
|
|
|
|||
|
|
@ -10,14 +10,19 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolGroupInput,
|
||||
ToolParameter,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
||||
|
||||
|
|
@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tool_runtime_memory() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="memory-runtime",
|
||||
provider_type="inline::memory-runtime",
|
||||
config={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||
TOOL_RUNTIME_FIXTURES = ["memory"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety", "memory", "agents"]:
|
||||
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if key == "inference":
|
||||
|
|
@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
metadata={"embedding_dimension": 384},
|
||||
)
|
||||
)
|
||||
tool_groups = [
|
||||
ToolGroupInput(
|
||||
tool_group_id="memory_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
name="memory",
|
||||
description="memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="session_id",
|
||||
description="session id",
|
||||
parameter_type="string",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
description="messages",
|
||||
parameter_type="list",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="attachments",
|
||||
description="attachments",
|
||||
parameter_type="list",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
),
|
||||
provider_id="memory-runtime",
|
||||
)
|
||||
]
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||
providers,
|
||||
provider_data,
|
||||
models=models,
|
||||
shields=[safety_shield] if safety_shield else [],
|
||||
tool_groups=tool_groups,
|
||||
)
|
||||
return test_stack
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ from llama_stack.providers.datatypes import Api
|
|||
#
|
||||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
|
@ -255,17 +254,8 @@ class TestAgents:
|
|||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
"tools": [],
|
||||
"preprocessing_tools": ["memory"],
|
||||
"tool_choice": ToolChoice.auto,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
|
|||
from llama_stack.apis.models import ModelInput
|
||||
from llama_stack.apis.scoring_functions import ScoringFnInput
|
||||
from llama_stack.apis.shields import ShieldInput
|
||||
|
||||
from llama_stack.apis.tools import ToolGroupInput
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
|
|
@ -43,6 +43,7 @@ async def construct_stack_for_test(
|
|||
datasets: Optional[List[DatasetInput]] = None,
|
||||
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||
tool_groups: Optional[List[ToolGroupInput]] = None,
|
||||
) -> TestStack:
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config = dict(
|
||||
|
|
@ -56,6 +57,7 @@ async def construct_stack_for_test(
|
|||
datasets=datasets or [],
|
||||
scoring_fns=scoring_fns or [],
|
||||
eval_tasks=eval_tasks or [],
|
||||
tool_groups=tool_groups or [],
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue