Mirror of https://github.com/meta-llama/llama-stack.git

Commit abe1cc6303 ("Fix agents tests"), parent 38257a9cbe
7 changed files with 51 additions and 24 deletions
@@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            memory_bank = VectorMemoryBank(
-                identifier=bank_id,
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
+            await self.memory_banks_api.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                ),
             )
-            await self.memory_banks_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
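Taken together, this hunk stops building a VectorMemoryBank object and instead registers the bank through the memory_banks API. A minimal sketch of the resulting flow, assuming the import path for VectorMemoryBankParams (it is not visible in this diff) and a hypothetical free-standing helper name:

    # Sketch only: mirrors the updated ChatAgent logic above. The import path and
    # the helper name are assumptions, not part of this commit.
    from llama_stack.apis.memory_banks import VectorMemoryBankParams  # assumed path


    async def ensure_session_memory_bank(agent, session_id: str, session_info) -> str:
        if session_info.memory_bank_id is None:
            bank_id = f"memory_bank_{session_id}"
            # Register through the memory_banks API instead of constructing a
            # VectorMemoryBank and passing the object to register_memory_bank.
            await agent.memory_banks_api.register_memory_bank(
                memory_bank_id=bank_id,
                params=VectorMemoryBankParams(
                    embedding_model="all-MiniLM-L6-v2",
                    chunk_size_in_tokens=512,
                ),
            )
            await agent.storage.add_memory_bank_to_session(session_id, bank_id)
            return bank_id
        return session_info.memory_bank_id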
@@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "meta_reference",
             "safety": "llama_guard",
-            "memory": "meta_reference",
+            "memory": "faiss",
             "agents": "meta_reference",
         },
         id="meta_reference",
@@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "ollama",
             "safety": "llama_guard",
-            "memory": "meta_reference",
+            "memory": "faiss",
             "agents": "meta_reference",
         },
         id="ollama",
@@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "meta_reference",
+            "memory": "faiss",
             "agents": "meta_reference",
         },
         id="together",
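These three hunks only change the memory provider fixture in each combination from "meta_reference" to "faiss". As a hedged illustration (the pytest.param wrapper is inferred from the id=... keyword and is not itself shown in the diff), one entry now looks roughly like:

    import pytest

    # Assumed shape of a single combination entry after this commit; only the dict
    # keys/values and the id string are taken from the diff.
    META_REFERENCE_COMBINATION = pytest.param(
        {
            "inference": "meta_reference",
            "safety": "llama_guard",
            "memory": "faiss",  # previously "meta_reference"
            "agents": "meta_reference",
        },
        id="meta_reference",
    )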
@@ -9,6 +9,7 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.models import Model
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.inline.agents.meta_reference import (
@@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
 
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 from ..conftest import ProviderFixture, remote_stack_fixture
+from ..safety.fixtures import get_shield_to_register
+
+
+def pick_inference_model(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
+    # multiple models when you need to run a safety model in addition to normal agent
+    # inference model. We filter off the safety model by looking for "Llama-Guard"
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+    assert inference_model is not None
+    return inference_model
 
 
 @pytest.fixture(scope="session")
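Since pick_inference_model is shown in full above, here is a small self-contained usage sketch; the model names are placeholders rather than values taken from the diff:

    def pick_inference_model(inference_model):
        # Copied from the hunk above: when a list is passed, drop the safety model
        # (identified by "Llama-Guard") and keep the agent inference model.
        if isinstance(inference_model, list):
            inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
        assert inference_model is not None
        return inference_model


    # Placeholder model names, for illustration only.
    assert pick_inference_model("Llama3.1-8B-Instruct") == "Llama3.1-8B-Instruct"
    assert pick_inference_model(["Llama3.1-8B-Instruct", "Llama-Guard-3-8B"]) == "Llama3.1-8B-Instruct"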
@@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request):
+async def agents_stack(request, inference_model, safety_model):
     fixture_dict = request.param
 
     providers = {}
@@ -60,9 +71,28 @@ async def agents_stack(request):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
+    inf_provider_id = providers["inference"][0].provider_id
+    safety_provider_id = providers["safety"][0].provider_id
+
+    shield = get_shield_to_register(
+        providers["safety"][0].provider_type, safety_provider_id, safety_model
+    )
+
+    inference_models = (
+        inference_model if isinstance(inference_model, list) else [inference_model]
+    )
     impls = await resolve_impls_for_test_v2(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
+        models=[
+            Model(
+                identifier=model,
+                provider_id=inf_provider_id,
+                provider_resource_id=model,
+            )
+            for model in inference_models
+        ],
+        shields=[shield],
     )
     return impls[Api.agents], impls[Api.memory]
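The agents_stack fixture now pre-registers the inference models and a shield when it resolves the test stack. A minimal sketch of the model objects it builds, grounded on the Model import added earlier in this commit; the identifier and provider id below are placeholders:

    from llama_stack.apis.models import Model  # import added by this commit

    inference_models = ["Llama3.1-8B-Instruct"]  # placeholder identifier
    inf_provider_id = "meta-reference"  # placeholder provider id

    # Same comprehension as in the hunk above: one Model entry per inference model,
    # reusing the model identifier as the provider resource id.
    models = [
        Model(
            identifier=model,
            provider_id=inf_provider_id,
            provider_resource_id=model,
        )
        for model in inference_models
    ]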
@@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import *  # noqa: F403
 # pytest -v -s llama_stack/providers/tests/agents/test_agents.py
 # -m "meta_reference"
 
+from .fixtures import pick_inference_model
+
 
 @pytest.fixture
 def common_params(inference_model):
-    # This is not entirely satisfactory. The fixture `inference_model` can correspond to
-    # multiple models when you need to run a safety model in addition to normal agent
-    # inference model. We filter off the safety model by looking for "Llama-Guard"
-    if isinstance(inference_model, list):
-        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
-    assert inference_model is not None
+    inference_model = pick_inference_model(inference_model)
 
     return dict(
         model=inference_model,
@@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture:
 
 
 @pytest.fixture(scope="session")
-def memory_meta_reference() -> ProviderFixture:
+def memory_faiss() -> ProviderFixture:
     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_id="faiss",
+                provider_type="inline::faiss",
                 config=FaissImplConfig(
                     kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
                 ).model_dump(),
@@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture:
     )
 
 
-MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
+MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
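The memory_meta_reference fixture is renamed to memory_faiss and the MEMORY_FIXTURES list is updated to match. A hedged sketch of the naming convention this implies (the memory_ prefix mapping is an assumption about the surrounding conftest, not something shown in this diff):

    MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


    def fixture_name_for(memory_fixture: str) -> str:
        # Assumed convention: each list entry names a fixture called "memory_<entry>",
        # which is why renaming the fixture and changing the list entry go together.
        return f"memory_{memory_fixture}"


    assert fixture_name_for("faiss") == "memory_faiss"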
@@ -44,7 +44,6 @@ def sample_documents():
 
 
 async def register_memory_bank(banks_impl: MemoryBanks):
-
     return await banks_impl.register_memory_bank(
         memory_bank_id="test_bank",
         params=VectorMemoryBankParams(
@@ -71,7 +70,7 @@ class TestMemory:
        # but so far we don't have an unregister API unfortunately, so be careful
        _, banks_impl = memory_stack
 
-        bank = await banks_impl.register_memory_bank(
+        await banks_impl.register_memory_bank(
            memory_bank_id="test_bank_no_provider",
            params=VectorMemoryBankParams(
                embedding_model="all-MiniLM-L6-v2",
@@ -99,7 +99,7 @@ async def safety_stack(inference_model, safety_model, request):
         provider_data.update(safety_fixture.provider_data)
 
     shield_provider_type = safety_fixture.providers[0].provider_type
-    shield = get_shield(
+    shield = get_shield_to_register(
         shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
     )
 
@@ -120,7 +120,7 @@ async def safety_stack(inference_model, safety_model, request):
     return impls[Api.safety], impls[Api.shields], shield
 
 
-def get_shield(provider_type: str, provider_id: str, safety_model: str):
+def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
    shield_config = {}
    shield_type = ShieldType.llama_guard
    identifier = "llama_guard"
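get_shield is renamed to get_shield_to_register, matching the new import in the agents fixtures. Its call shape is visible in the safety_stack hunk above; in this sketch the import path is inferred from the relative import added earlier in the commit and the argument values are placeholders:

    from llama_stack.providers.tests.safety.fixtures import (  # inferred path
        get_shield_to_register,
    )

    # Placeholder values, illustrating only the
    # (provider_type, provider_id, safety_model) call shape.
    shield_provider_type = "inline::llama-guard"  # placeholder provider type
    shield_provider_id = "llama-guard"            # placeholder provider id
    safety_model = "Llama-Guard-3-8B"             # placeholder model name

    shield = get_shield_to_register(shield_provider_type, shield_provider_id, safety_model)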