diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 513c69f04..1abf66301 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -36,8 +36,3 @@ class Resource(BaseModel): type: ResourceType = Field( description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)" ) - - # If the provider_resource_identifier is not set, set it to the identifier - def model_post_init(self, __context) -> None: - if self.provider_resource_identifier is None: - self.provider_resource_identifier = self.identifier diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py index 0fd854967..02aa7c2a4 100644 --- a/llama_stack/apis/shields/client.py +++ b/llama_stack/apis/shields/client.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import json from typing import List, Optional @@ -35,12 +34,23 @@ class ShieldsClient(Shields): response.raise_for_status() return [Shield(**x) for x in response.json()] - async def register_shield(self, shield: Shield) -> None: + async def register_shield( + self, + shield_id: str, + shield_type: ShieldType, + provider_resource_identifier: Optional[str], + provider_id: Optional[str], + params: Optional[Dict[str, Any]], + ) -> None: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/shields/register", json={ - "shield": json.loads(shield.json()), + "shield_id": shield_id, + "shield_type": shield_type, + "provider_resource_identifier": provider_resource_identifier, + "provider_id": provider_id, + "params": params, }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 06367a62d..3f0da8573 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -38,4 +38,11 @@ class Shields(Protocol): async def get_shield(self, identifier: str) -> Optional[Shield]: ... @webmethod(route="/shields/register", method="POST") - async def register_shield(self, shield: Shield) -> None: ... + async def register_shield( + self, + shield_id: str, + shield_type: ShieldType, + provider_resource_identifier: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Shield: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 0e4653133..1643091e8 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -150,8 +150,17 @@ class SafetyRouter(Safety): async def shutdown(self) -> None: pass - async def register_shield(self, shield: Shield) -> None: - await self.routing_table.register_shield(shield) + async def register_shield( + self, + shield_id: str, + shield_type: ShieldType, + provider_resource_identifier: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Shield: + return await self.routing_table.register_shield( + shield_id, shield_type, provider_resource_identifier, provider_id, params + ) async def run_shield( self, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 07ecbe3f9..5378661fd 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -215,8 +215,44 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier(identifier) - async def register_shield(self, shield: Shield) -> None: + async def register_shield( + self, + shield_id: str, + shield_type: ShieldType, + provider_resource_identifier: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Shield: + if provider_resource_identifier is None: + provider_resource_identifier = shield_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this shield type + if len(self.impls_by_provider_id) == 1: + provider = list(self.impls_by_provider_id.values())[0] + if ( + hasattr(provider, "supported_shield_types") + and shield_type in await provider.supported_shield_types() + ): + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + f"No provider available that supports shield type {shield_type}" + ) + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if params is None: + params = {} + shield = Shield( + identifier=shield_id, + shield_type=shield_type, + provider_resource_identifier=provider_resource_identifier, + provider_id=provider_id, + params=params, + ) await self.register_object(shield) + return shield class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 29c551382..68543b3ce 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks import MemoryBankDef from llama_stack.apis.models import ModelDef from llama_stack.apis.scoring_functions import ScoringFnDef -from llama_stack.apis.shields import Shield +from llama_stack.apis.shields import Shield, ShieldType @json_schema_type @@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def supported_shield_types(self) -> List[ShieldType]: ... + class MemoryBanksProtocolPrivate(Protocol): async def list_memory_banks(self) -> List[MemoryBankDef]: ... diff --git a/llama_stack/providers/inline/safety/meta_reference/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py index 76c54ecfb..787150e22 100644 --- a/llama_stack/providers/inline/safety/meta_reference/safety.py +++ b/llama_stack/providers/inline/safety/meta_reference/safety.py @@ -21,6 +21,7 @@ from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield PROMPT_GUARD_MODEL = "Prompt-Guard-86M" +SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard] class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): @@ -46,6 +47,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.shield_type not in self.available_shields: raise ValueError(f"Shield type {shield.shield_type} not supported") + async def supported_shield_types(self) -> List[ShieldType]: + return SUPPORTED_SHIELDS + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 6c41edcdb..e9955ab66 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) BEDROCK_SUPPORTED_SHIELDS = [ - ShieldType.generic_content_shield.value, + ShieldType.generic_content_shield, ] @@ -53,6 +53,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.identifier} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def supported_shield_types(self) -> List[ShieldType]: + return BEDROCK_SUPPORTED_SHIELDS + async def run_shield( self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 000318d5d..035288cf8 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -7,7 +7,7 @@ import pytest import pytest_asyncio -from llama_stack.apis.shields import Shield, ShieldType +from llama_stack.apis.shields import ShieldType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.meta_reference import ( @@ -95,10 +95,10 @@ async def safety_stack(inference_model, safety_model, request): shields_impl = impls[Api.shields] # Register the appropriate shield based on provider type - provider_id = safety_fixture.providers[0].provider_id provider_type = safety_fixture.providers[0].provider_type shield_config = {} + shield_type = ShieldType.llama_guard identifier = "llama_guard" if provider_type == "meta-reference": shield_config["model"] = safety_model @@ -107,12 +107,11 @@ async def safety_stack(inference_model, safety_model, request): elif provider_type == "remote::bedrock": identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") + shield_type = ShieldType.generic_content_shield - # Create shield - shield = Shield( - identifier=identifier, - shield_type=ShieldType.llama_guard, - provider_id=provider_id, + shield = await shields_impl.register_shield( + shield_id=identifier, + shield_type=shield_type, params=shield_config, ) diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index e13b8417c..2dd748a60 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -19,9 +19,15 @@ from llama_stack.distribution.datatypes import * # noqa: F403 class TestSafety: @pytest.mark.asyncio - async def test_shield_list(self, safety_stack): + async def test_new_shield(self, safety_stack): _, shields_impl, shield = safety_stack - await shields_impl.register_shield(shield) + assert shield is not None + assert shield.provider_resource_identifier == shield.identifier + assert shield.provider_id is not None + + @pytest.mark.asyncio + async def test_shield_list(self, safety_stack): + _, shields_impl, _ = safety_stack response = await shields_impl.list_shields() assert isinstance(response, list) assert len(response) >= 1 @@ -32,9 +38,7 @@ class TestSafety: @pytest.mark.asyncio async def test_run_shield(self, safety_stack): - safety_impl, shields_impl, shield = safety_stack - - await shields_impl.register_shield(shield) + safety_impl, _, shield = safety_stack response = await safety_impl.run_shield( shield_id=shield.identifier,