mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 08:10:10 +00:00
Merge branch 'meta-llama:main' into main
This commit is contained in:
commit
0ac4d2fced
119 changed files with 5069 additions and 2779 deletions
|
|
@ -7,13 +7,12 @@
|
|||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
|
|
@ -21,6 +20,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
|
|
@ -31,6 +31,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
|
|
@ -42,6 +43,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
|
|
@ -52,6 +54,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
id="fireworks",
|
||||
marks=pytest.mark.fireworks,
|
||||
|
|
@ -62,6 +65,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "remote",
|
||||
"memory": "remote",
|
||||
"agents": "remote",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
id="remote",
|
||||
marks=pytest.mark.remote,
|
||||
|
|
@ -117,6 +121,7 @@ def pytest_generate_tests(metafunc):
|
|||
"safety": SAFETY_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"agents": AGENTS_FIXTURES,
|
||||
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
|
|
|
|||
|
|
@ -11,13 +11,12 @@ import pytest_asyncio
|
|||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
||||
|
||||
|
|
@ -59,12 +58,18 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(request, inference_model, safety_shield):
|
||||
async def agents_stack(
|
||||
request,
|
||||
inference_model,
|
||||
safety_shield,
|
||||
tool_group_input_memory,
|
||||
tool_group_input_tavily_search,
|
||||
):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety", "memory", "agents"]:
|
||||
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if key == "inference":
|
||||
|
|
@ -113,10 +118,11 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||
providers,
|
||||
provider_data,
|
||||
models=models,
|
||||
shields=[safety_shield] if safety_shield else [],
|
||||
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
|
||||
)
|
||||
return test_stack
|
||||
|
|
|
|||
|
|
@ -5,22 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentTurnResponseEventType,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
Attachment,
|
||||
MemoryToolDefinition,
|
||||
SearchEngineType,
|
||||
SearchToolDefinition,
|
||||
Document,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolChoice,
|
||||
|
|
@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
|
|||
#
|
||||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
|
@ -51,7 +45,7 @@ def common_params(inference_model):
|
|||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
toolgroups=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
|
|
@ -88,73 +82,6 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
async def create_agent_turn_with_search_tool(
|
||||
agents_stack: Dict[str, object],
|
||||
search_query_messages: List[object],
|
||||
common_params: Dict[str, str],
|
||||
search_tool_definition: SearchToolDefinition,
|
||||
) -> None:
|
||||
"""
|
||||
Create an agent turn with a search tool.
|
||||
|
||||
Args:
|
||||
agents_stack (Dict[str, object]): The agents stack.
|
||||
search_query_messages (List[object]): The search query messages.
|
||||
common_params (Dict[str, str]): The common parameters.
|
||||
search_tool_definition (SearchToolDefinition): The search tool definition.
|
||||
"""
|
||||
|
||||
# Create an agent with the search tool
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [search_tool_definition],
|
||||
}
|
||||
)
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_stack.impls[Api.agents], agent_config
|
||||
)
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk
|
||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||
**turn_request
|
||||
)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(
|
||||
|
|
@ -227,7 +154,7 @@ class TestAgents:
|
|||
check_turn_complete_event(turn_response, session_id, sample_messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
async def test_rag_agent(
|
||||
self,
|
||||
agents_stack,
|
||||
attachment_message,
|
||||
|
|
@ -243,29 +170,17 @@ class TestAgents:
|
|||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
documents = [
|
||||
Document(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
"toolgroups": ["builtin::memory"],
|
||||
"tool_choice": ToolChoice.auto,
|
||||
}
|
||||
)
|
||||
|
|
@ -275,7 +190,7 @@ class TestAgents:
|
|||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
documents=documents,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [
|
||||
|
|
@ -298,22 +213,6 @@ class TestAgents:
|
|||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
search_tool_definition = SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value,
|
||||
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.brave,
|
||||
)
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
|
|
@ -321,14 +220,57 @@ class TestAgents:
|
|||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
search_tool_definition = SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value, # place holder only
|
||||
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.tavily,
|
||||
# Create an agent with the toolgroup
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"toolgroups": ["builtin::web_search"],
|
||||
}
|
||||
)
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_stack.impls[Api.agents], agent_config
|
||||
)
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk
|
||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||
**turn_request
|
||||
)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type
|
||||
== StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||
assert actual_tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
def check_event_types(turn_response):
|
||||
|
|
|
|||
|
|
@ -157,4 +157,5 @@ pytest_plugins = [
|
|||
"llama_stack.providers.tests.scoring.fixtures",
|
||||
"llama_stack.providers.tests.eval.fixtures",
|
||||
"llama_stack.providers.tests.post_training.fixtures",
|
||||
"llama_stack.providers.tests.tools.fixtures",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
|||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
|
|||
from llama_stack.apis.models import ModelInput
|
||||
from llama_stack.apis.scoring_functions import ScoringFnInput
|
||||
from llama_stack.apis.shields import ShieldInput
|
||||
|
||||
from llama_stack.apis.tools import ToolGroupInput
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
|
|
@ -43,6 +43,7 @@ async def construct_stack_for_test(
|
|||
datasets: Optional[List[DatasetInput]] = None,
|
||||
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||
tool_groups: Optional[List[ToolGroupInput]] = None,
|
||||
) -> TestStack:
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config = dict(
|
||||
|
|
@ -56,6 +57,7 @@ async def construct_stack_for_test(
|
|||
datasets=datasets or [],
|
||||
scoring_fns=scoring_fns or [],
|
||||
eval_tasks=eval_tasks or [],
|
||||
tool_groups=tool_groups or [],
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
try:
|
||||
|
|
|
|||
5
llama_stack/providers/tests/tools/__init__.py
Normal file
5
llama_stack/providers/tests/tools/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
65
llama_stack/providers/tests/tools/conftest.py
Normal file
65
llama_stack/providers/tests/tools/conftest.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES
|
||||
from .fixtures import TOOL_RUNTIME_FIXTURES
|
||||
|
||||
# Provider combinations used when no overrides are supplied. Each dict maps an
# API name to the suffix of the provider fixture to use for that API.
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "inference": "together",
            "safety": "llama_guard",
            "memory": "faiss",
            "tool_runtime": "memory_and_search",
        },
        marks=pytest.mark.together,
        id="together",
    ),
]
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the provider-specific markers used by the tool tests.

    Args:
        config: The pytest config object; one ``markers`` ini line is added
            per provider mark.
    """
    provider_marks = ["together"]
    for mark in provider_marks:
        marker_line = f"{mark}: marks tests as {mark} specific"
        config.addinivalue_line("markers", marker_line)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
    """Add the command-line options that pick the model and shield under test.

    Args:
        parser: The pytest option parser; ``--inference-model`` and
            ``--safety-shield`` are registered on it, in that order.
    """
    # (flag, default value, help text) for each option, in registration order.
    option_specs = (
        (
            "--inference-model",
            "meta-llama/Llama-3.2-3B-Instruct",
            "Specify the inference model to use for testing",
        ),
        (
            "--safety-shield",
            "meta-llama/Llama-Guard-3-1B",
            "Specify the safety shield to use for testing",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.addoption(
            flag,
            action="store",
            default=default_value,
            help=help_text,
        )
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
    """Parametrize the ``tools_stack`` fixture with provider combinations.

    Tests that do not request ``tools_stack`` are left untouched. For the
    others, the combinations come from ``get_provider_fixture_overrides`` and
    fall back to ``DEFAULT_PROVIDER_COMBINATIONS`` when that call returns a
    falsy value.

    Args:
        metafunc: The pytest ``Metafunc`` object for the test being collected.
    """
    if "tools_stack" not in metafunc.fixturenames:
        return

    available_fixtures = {
        "inference": INFERENCE_FIXTURES,
        "safety": SAFETY_FIXTURES,
        "memory": MEMORY_FIXTURES,
        "tool_runtime": TOOL_RUNTIME_FIXTURES,
    }
    combinations = (
        get_provider_fixture_overrides(metafunc.config, available_fixtures)
        or DEFAULT_PROVIDER_COMBINATIONS
    )
    # indirect=True hands each combination to the fixture as request.param
    # instead of passing it to the test function directly.
    metafunc.parametrize("tools_stack", combinations, indirect=True)
|
||||
130
llama_stack/providers/tests/tools/fixtures.py
Normal file
130
llama_stack/providers/tests/tools/fixtures.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import ToolGroupInput
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
    """Provide the memory, Tavily search and Wolfram Alpha runtime providers.

    The two remote providers take their API keys from the
    ``TAVILY_SEARCH_API_KEY`` and ``WOLFRAM_ALPHA_API_KEY`` environment
    variables. Both are read with ``os.environ[...]``, so a missing variable
    raises ``KeyError`` while this fixture is being set up.
    """
    memory_runtime = Provider(
        provider_id="memory-runtime",
        provider_type="inline::memory-runtime",
        config={},
    )
    tavily_search = Provider(
        provider_id="tavily-search",
        provider_type="remote::tavily-search",
        config={"api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
    )
    wolfram_alpha = Provider(
        provider_id="wolfram-alpha",
        provider_type="remote::wolfram-alpha",
        config={"api_key": os.environ["WOLFRAM_ALPHA_API_KEY"]},
    )
    return ProviderFixture(
        providers=[memory_runtime, tavily_search, wolfram_alpha],
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
    """Return the ``builtin::memory`` tool group bound to ``memory-runtime``."""
    memory_group = ToolGroupInput(
        toolgroup_id="builtin::memory",
        provider_id="memory-runtime",
    )
    return memory_group
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
    """Return the ``builtin::web_search`` tool group bound to ``tavily-search``."""
    web_search_group = ToolGroupInput(
        toolgroup_id="builtin::web_search",
        provider_id="tavily-search",
    )
    return web_search_group
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
    """Return the ``builtin::wolfram_alpha`` tool group bound to ``wolfram-alpha``."""
    wolfram_group = ToolGroupInput(
        toolgroup_id="builtin::wolfram_alpha",
        provider_id="wolfram-alpha",
    )
    return wolfram_group
|
||||
|
||||
|
||||
# Suffixes of the tool-runtime provider fixtures defined in this module; the
# matching fixture is looked up by name as `tool_runtime_<suffix>`.
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
async def tools_stack(
    request,
    inference_model,
    tool_group_input_memory,
    tool_group_input_tavily_search,
    tool_group_input_wolfram_alpha,
):
    """Build the test stack used by the tool-runtime tests.

    ``request.param`` is a dict that maps an API name to a provider fixture
    suffix. For each of ``inference``, ``memory`` and ``tool_runtime`` the
    fixture named ``<api>_<suffix>`` is resolved and its providers collected.
    A ``sentence-transformers`` inference provider is added to serve the
    ``all-MiniLM-L6-v2`` embedding model registered alongside the LLM models.

    Args:
        request: The pytest request carrying the provider combination.
        inference_model: One model id or a list of model ids to register as
            LLM models on the first inference provider.
        tool_group_input_memory: Tool group input for the memory tool.
        tool_group_input_tavily_search: Tool group input for web search.
        tool_group_input_wolfram_alpha: Tool group input for Wolfram Alpha.

    Returns:
        The stack returned by ``construct_stack_for_test``.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "memory", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        # Copy the list: the append below must not modify the list held by
        # the provider fixture object, which other consumers may also read.
        providers[key] = list(fixture.providers)
        if key == "inference":
            providers[key].append(
                Provider(
                    provider_id="tools_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    inference_models = (
        inference_model if isinstance(inference_model, list) else [inference_model]
    )
    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=providers["inference"][0].provider_id,
        )
        for model in inference_models
    ]
    # Embedding model served by the sentence-transformers provider added above.
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="tools_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    test_stack = await construct_stack_for_test(
        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
        providers,
        provider_data,
        models=models,
        tool_groups=[
            tool_group_input_tavily_search,
            tool_group_input_wolfram_alpha,
            tool_group_input_memory,
        ],
    )
    return test_stack
|
||||
127
llama_stack/providers/tests/tools/test_tools.py
Normal file
127
llama_stack/providers/tests/tools/test_tools.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.memory import MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBankParams
|
||||
from llama_stack.apis.tools import ToolInvocationResult
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
@pytest.fixture
def sample_search_query():
    """Return the query text sent to the web search tool."""
    query = "What are the latest developments in quantum computing?"
    return query
|
||||
|
||||
|
||||
@pytest.fixture
def sample_wolfram_alpha_query():
    """Return the query text sent to the Wolfram Alpha tool."""
    query = "What is the square root of 16?"
    return query
|
||||
|
||||
|
||||
@pytest.fixture
def sample_documents():
    """Return memory-bank documents that point at torchtune tutorial sources.

    Each document's content is a raw GitHub URL for one tutorial file and its
    id is ``num-<index>``, where the index is the file's position in the list.
    """
    tutorial_files = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    documents = []
    for i, url in enumerate(tutorial_files):
        documents.append(
            MemoryBankDocument(
                document_id=f"num-{i}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
        )
    return documents
|
||||
|
||||
|
||||
class TestTools:
    """Tests that invoke the registered tools through the tool-runtime API."""

    @pytest.mark.asyncio
    async def test_web_search_tool(self, tools_stack, sample_search_query):
        """Test the web search tool functionality."""
        if "TAVILY_SEARCH_API_KEY" not in os.environ:
            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

        runtime = tools_stack.impls[Api.tool_runtime]

        # Run the tool with the sample query.
        search_args = {"query": sample_search_query}
        result = await runtime.invoke_tool(tool_name="web_search", args=search_args)

        # The result must carry non-empty string content.
        assert isinstance(result, ToolInvocationResult)
        assert result.content is not None
        assert len(result.content) > 0
        assert isinstance(result.content, str)

    @pytest.mark.asyncio
    async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
        """Test the wolfram alpha tool functionality."""
        if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
            pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")

        runtime = tools_stack.impls[Api.tool_runtime]

        wolfram_args = {"query": sample_wolfram_alpha_query}
        result = await runtime.invoke_tool(
            tool_name="wolfram_alpha", args=wolfram_args
        )

        # The result must carry non-empty string content.
        assert isinstance(result, ToolInvocationResult)
        assert result.content is not None
        assert len(result.content) > 0
        assert isinstance(result.content, str)

    @pytest.mark.asyncio
    async def test_memory_tool(self, tools_stack, sample_documents):
        """Test the memory tool functionality."""
        banks_api = tools_stack.impls[Api.memory_banks]
        memory_api = tools_stack.impls[Api.memory]
        runtime = tools_stack.impls[Api.tool_runtime]

        # Register the bank that the documents are inserted into.
        bank_params = VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        )
        await banks_api.register_memory_bank(
            memory_bank_id="test_bank",
            params=bank_params,
            provider_id="faiss",
        )

        # Fill the bank with the sample documents.
        await memory_api.insert_documents(
            bank_id="test_bank",
            documents=sample_documents,
        )

        # Query the memory tool against that bank.
        question = UserMessage(
            content="What are the main topics covered in the documentation?",
        )
        memory_args = {
            "messages": [question],
            "memory_bank_ids": ["test_bank"],
        }
        result = await runtime.invoke_tool(tool_name="memory", args=memory_args)

        # The result must carry non-empty content.
        assert isinstance(result, ToolInvocationResult)
        assert result.content is not None
        assert len(result.content) > 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue