mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
Fix agents tests
This commit is contained in:
parent
38257a9cbe
commit
abe1cc6303
7 changed files with 51 additions and 24 deletions
|
@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
memory_bank = VectorMemoryBank(
|
||||
identifier=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
await self.memory_banks_api.register_memory_bank(memory_bank)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
|
|
@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="meta_reference",
|
||||
|
@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "ollama",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="ollama",
|
||||
|
@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"inference": "together",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="together",
|
||||
|
|
|
@ -9,6 +9,7 @@ import tempfile
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
|
@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
|||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..safety.fixtures import get_shield_to_register
|
||||
|
||||
|
||||
def pick_inference_model(inference_model):
    """Return the single model name to use for agent inference.

    The `inference_model` fixture value may be one model name or a list of
    names (when a safety model has to be served next to the regular agent
    model). For a list, the first entry whose name does not contain
    "Llama-Guard" is returned; any other value is returned unchanged.

    Raises:
        AssertionError: if the value is None, or if the list contains no
            entry other than "Llama-Guard" models (including an empty list).
    """
    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
    # multiple models when you need to run a safety model in addition to normal agent
    # inference model. We filter off the safety model by looking for "Llama-Guard"
    if isinstance(inference_model, list):
        # Default to None so a list without a usable model reaches the assert
        # below instead of escaping as a bare StopIteration.
        inference_model = next(
            (m for m in inference_model if "Llama-Guard" not in m), None
        )
    assert inference_model is not None, "no non-Llama-Guard inference model available"
    return inference_model
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(request):
|
||||
async def agents_stack(request, inference_model, safety_model):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
|
@ -60,9 +71,28 @@ async def agents_stack(request):
|
|||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
inf_provider_id = providers["inference"][0].provider_id
|
||||
safety_provider_id = providers["safety"][0].provider_id
|
||||
|
||||
shield = get_shield_to_register(
|
||||
providers["safety"][0].provider_type, safety_provider_id, safety_model
|
||||
)
|
||||
|
||||
inference_models = (
|
||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||
)
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[
|
||||
Model(
|
||||
identifier=model,
|
||||
provider_id=inf_provider_id,
|
||||
provider_resource_id=model,
|
||||
)
|
||||
for model in inference_models
|
||||
],
|
||||
shields=[shield],
|
||||
)
|
||||
return impls[Api.agents], impls[Api.memory]
|
||||
|
|
|
@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
|||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def common_params(inference_model):
|
||||
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
|
||||
# multiple models when you need to run a safety model in addition to normal agent
|
||||
# inference model. We filter off the safety model by looking for "Llama-Guard"
|
||||
if isinstance(inference_model, list):
|
||||
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
|
||||
assert inference_model is not None
|
||||
inference_model = pick_inference_model(inference_model)
|
||||
|
||||
return dict(
|
||||
model=inference_model,
|
||||
|
|
|
@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_meta_reference() -> ProviderFixture:
|
||||
def memory_faiss() -> ProviderFixture:
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissImplConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
||||
).model_dump(),
|
||||
|
@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
|
||||
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -44,7 +44,6 @@ def sample_documents():
|
|||
|
||||
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
|
@ -71,7 +70,7 @@ class TestMemory:
|
|||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
bank = await banks_impl.register_memory_bank(
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
|
|
|
@ -99,7 +99,7 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
provider_data.update(safety_fixture.provider_data)
|
||||
|
||||
shield_provider_type = safety_fixture.providers[0].provider_type
|
||||
shield = get_shield(
|
||||
shield = get_shield_to_register(
|
||||
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
|
||||
)
|
||||
|
||||
|
@ -120,7 +120,7 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
return impls[Api.safety], impls[Api.shields], shield
|
||||
|
||||
|
||||
def get_shield(provider_type: str, provider_id: str, safety_model: str):
|
||||
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue