Refactor safety shield fixtures

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:58:58 -08:00
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions

View file

@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_model):
async def agents_stack(request, inference_model, safety_shield):
fixture_dict = request.param
providers = {}
@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
shield_input = get_shield_to_register(
providers["safety"][0].provider_type, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
)
for model in inference_models
],
shields=[shield_input],
shields=[safety_shield],
)
return impls[Api.agents], impls[Api.memory]