Fix agents tests

This commit is contained in:
Ashwin Bharambe 2024-11-11 22:31:09 -08:00
parent 38257a9cbe
commit abe1cc6303
7 changed files with 51 additions and 24 deletions

View file

@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
memory_bank = VectorMemoryBank(
identifier=bank_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.memory_banks_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id

View file

@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "meta_reference",
"safety": "llama_guard",
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="meta_reference",
@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "ollama",
"safety": "llama_guard",
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="ollama",
@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"inference": "together",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
"memory": "meta_reference",
"memory": "faiss",
"agents": "meta_reference",
},
id="together",

View file

@ -9,6 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
    """Return the single agent-inference model from *inference_model*.

    This is not entirely satisfactory. The fixture `inference_model` can
    correspond to multiple models when you need to run a safety model in
    addition to the normal agent inference model. We filter off the safety
    model by looking for "Llama-Guard" in its identifier.

    Raises:
        ValueError: if no non-safety model can be found (or the input is None).
    """
    if isinstance(inference_model, list):
        # Pick the first model that is not a Llama-Guard safety model.
        # Use a default of None so an all-safety list yields a clear error
        # below instead of an opaque StopIteration from next().
        inference_model = next(
            (m for m in inference_model if "Llama-Guard" not in m), None
        )
    # Raise instead of assert: asserts are stripped under `python -O`.
    if inference_model is None:
        raise ValueError("no non-Llama-Guard inference model available")
    return inference_model
@pytest.fixture(scope="session")
@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request):
async def agents_stack(request, inference_model, safety_model):
fixture_dict = request.param
providers = {}
@ -60,9 +71,28 @@ async def agents_stack(request):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inf_provider_id = providers["inference"][0].provider_id
safety_provider_id = providers["safety"][0].provider_id
shield = get_shield_to_register(
providers["safety"][0].provider_type, safety_provider_id, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
impls = await resolve_impls_for_test_v2(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
models=[
Model(
identifier=model,
provider_id=inf_provider_id,
provider_resource_id=model,
)
for model in inference_models
],
shields=[shield],
)
return impls[Api.agents], impls[Api.memory]

View file

@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
# -m "meta_reference"
from .fixtures import pick_inference_model
@pytest.fixture
def common_params(inference_model):
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
# multiple models when you need to run a safety model in addition to normal agent
# inference model. We filter off the safety model by looking for "Llama-Guard"
if isinstance(inference_model, list):
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
assert inference_model is not None
inference_model = pick_inference_model(inference_model)
return dict(
model=inference_model,

View file

@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture:
@pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture:
def memory_faiss() -> ProviderFixture:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture:
)
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")

View file

@ -44,7 +44,6 @@ def sample_documents():
async def register_memory_bank(banks_impl: MemoryBanks):
return await banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
@ -71,7 +70,7 @@ class TestMemory:
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
bank = await banks_impl.register_memory_bank(
await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",

View file

@ -99,7 +99,7 @@ async def safety_stack(inference_model, safety_model, request):
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield(
shield = get_shield_to_register(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
@ -120,7 +120,7 @@ async def safety_stack(inference_model, safety_model, request):
return impls[Api.safety], impls[Api.shields], shield
def get_shield(provider_type: str, provider_id: str, safety_model: str):
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"