Fix tool tests

This commit is contained in:
Ashwin Bharambe 2025-01-22 20:31:18 -08:00
parent 0bff6e1658
commit 6c205e1d5a
2 changed files with 4 additions and 5 deletions

View file

@ -8,8 +8,8 @@ import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES from .fixtures import TOOL_RUNTIME_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [ DEFAULT_PROVIDER_COMBINATIONS = [
@ -17,7 +17,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{ {
"inference": "together", "inference": "together",
"safety": "llama_guard", "safety": "llama_guard",
"memory": "faiss", "vector_io": "faiss",
"tool_runtime": "memory_and_search", "tool_runtime": "memory_and_search",
}, },
id="together", id="together",
@ -39,12 +39,11 @@ def pytest_generate_tests(metafunc):
available_fixtures = { available_fixtures = {
"inference": INFERENCE_FIXTURES, "inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES, "safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES, "vector_io": VECTOR_IO_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES, "tool_runtime": TOOL_RUNTIME_FIXTURES,
} }
combinations = ( combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS or DEFAULT_PROVIDER_COMBINATIONS
) )
print(combinations)
metafunc.parametrize("tools_stack", combinations, indirect=True) metafunc.parametrize("tools_stack", combinations, indirect=True)

View file

@ -88,7 +88,7 @@ class TestTools:
tools_impl = tools_stack.impls[Api.tool_runtime] tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank # Register memory bank
await vector_dbs_impl.register( await vector_dbs_impl.register_vector_db(
vector_db_id="test_bank", vector_db_id="test_bank",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,