mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
more fixes (some fixes to pre-existing issues in safety fixture)
This commit is contained in:
parent
7507cd487f
commit
15ffceb533
6 changed files with 25 additions and 9 deletions
|
@ -16,6 +16,7 @@ from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
|||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -89,7 +90,21 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
|
||||
# Register the appropriate shield based on provider type
|
||||
provider_type = safety_fixture.providers[0].provider_type
|
||||
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
|
||||
|
||||
provider_id = inference_fixture.providers[0].provider_id
|
||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
||||
await impls[Api.models].register_model(
|
||||
model_id=inference_model,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
|
||||
|
||||
async def create_and_register_shield(
|
||||
provider_type: str, safety_model: str, shields_impl
|
||||
):
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
|
@ -102,10 +117,8 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||
shield_type = ShieldType.generic_content_shield
|
||||
|
||||
shield = await shields_impl.register_shield(
|
||||
return await shields_impl.register_shield(
|
||||
shield_id=identifier,
|
||||
shield_type=shield_type,
|
||||
params=shield_config,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue