Allow specifying resources in StackRunConfig

This commit is contained in:
Ashwin Bharambe 2024-11-11 22:08:51 -08:00
parent 8035fa1869
commit 38257a9cbe
9 changed files with 151 additions and 102 deletions

View file

@ -7,7 +7,9 @@
import pytest
import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.apis.models import Model
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[
Model(
identifier=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
provider_resource_id=inference_model,
)
],
shields=[shield],
)
safety_impl = impls[Api.safety]
shields_impl = impls[Api.shields]
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
return impls[Api.safety], impls[Api.shields], shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
def get_shield(provider_type: str, provider_id: str, safety_model: str):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -134,8 +133,10 @@ async def create_and_register_shield(
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
return await shields_impl.register_shield(
shield_id=identifier,
return Shield(
identifier=identifier,
shield_type=shield_type,
params=shield_config,
provider_id=provider_id,
provider_resource_id=identifier,
)