llama-stack/llama_stack/providers/tests/tools/fixtures.py
Yuan Tang 34ab7a3b6c
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
2025-02-02 06:46:45 -08:00

133 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
    """Return the tool-runtime provider fixture: RAG runtime, Tavily search, Wolfram Alpha.

    The two remote providers take their API keys from the environment.

    Raises:
        KeyError: If ``TAVILY_SEARCH_API_KEY`` or ``WOLFRAM_ALPHA_API_KEY``
            is not set in the environment.
    """
    rag_runtime = Provider(
        provider_id="rag-runtime",
        provider_type="inline::rag-runtime",
        config={},
    )

    # Environment lookups use plain indexing on purpose: a missing key fails
    # the fixture immediately instead of configuring a provider without one.
    tavily_key = os.environ["TAVILY_SEARCH_API_KEY"]
    tavily_search = Provider(
        provider_id="tavily-search",
        provider_type="remote::tavily-search",
        config={"api_key": tavily_key},
    )

    wolfram_key = os.environ["WOLFRAM_ALPHA_API_KEY"]
    wolfram_alpha = Provider(
        provider_id="wolfram-alpha",
        provider_type="remote::wolfram-alpha",
        config={"api_key": wolfram_key},
    )

    return ProviderFixture(providers=[rag_runtime, tavily_search, wolfram_alpha])
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
    """Return the ``builtin::rag`` tool group bound to the ``rag-runtime`` provider."""
    rag_group = ToolGroupInput(
        toolgroup_id="builtin::rag",
        provider_id="rag-runtime",
    )
    return rag_group
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
    """Return the ``builtin::web_search`` tool group bound to the ``tavily-search`` provider."""
    web_search_group = ToolGroupInput(
        toolgroup_id="builtin::web_search",
        provider_id="tavily-search",
    )
    return web_search_group
@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
    """Return the ``builtin::wolfram_alpha`` tool group bound to the ``wolfram-alpha`` provider."""
    wolfram_group = ToolGroupInput(
        toolgroup_id="builtin::wolfram_alpha",
        provider_id="wolfram-alpha",
    )
    return wolfram_group
# Names of the available tool-runtime provider fixtures, without the
# "tool_runtime_" prefix: tools_stack resolves a fixture as f"{key}_{name}",
# so "memory_and_search" maps to tool_runtime_memory_and_search above.
# NOTE(review): presumably consumed by the test conftest for parametrization;
# the consumer is not visible in this file, so confirm.
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(
    request,
    inference_model,
    tool_group_input_memory,
    tool_group_input_tavily_search,
    tool_group_input_wolfram_alpha,
):
    """Construct the test stack used by the tools tests.

    ``request.param`` must be a mapping with the keys ``"inference"``,
    ``"vector_io"`` and ``"tool_runtime"``; each value is the suffix of a
    provider fixture that is looked up as ``f"{key}_{value}"``.

    Args:
        request: The pytest request object carrying the parametrization.
        inference_model: A single model id or a list of model ids. Each is
            registered as an LLM on the first inference provider.
        tool_group_input_memory: Tool group input for the RAG tool group.
        tool_group_input_tavily_search: Tool group input for web search.
        tool_group_input_wolfram_alpha: Tool group input for Wolfram Alpha.

    Returns:
        Whatever ``construct_stack_for_test`` returns for the tool_groups,
        inference, vector_io and tool_runtime APIs.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "vector_io", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        # Copy the list instead of aliasing it. pytest caches fixture values
        # for their scope, so appending to ``fixture.providers`` in place
        # would modify the provider fixture object itself. A shared inference
        # fixture would then pick up a duplicate "tools_memory_provider" each
        # time this fixture is built, and expose it to other consumers.
        providers[key] = list(fixture.providers)
        if key == "inference":
            # Extra inference provider that serves the embedding model
            # registered below.
            providers[key].append(
                Provider(
                    provider_id="tools_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    inference_models = inference_model if isinstance(inference_model, list) else [inference_model]
    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=providers["inference"][0].provider_id,
        )
        for model in inference_models
    ]
    # Embedding model, served by the provider appended above.
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="tools_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    test_stack = await construct_stack_for_test(
        [
            Api.tool_groups,
            Api.inference,
            Api.vector_io,
            Api.tool_runtime,
        ],
        providers,
        provider_data,
        models=models,
        tool_groups=[
            tool_group_input_tavily_search,
            tool_group_input_wolfram_alpha,
            tool_group_input_memory,
        ],
    )
    return test_stack