From fe072620c8e66a0e1c3236b1aae709bdb9ed1893 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 8 Nov 2024 12:00:36 -0800 Subject: [PATCH] address feedback --- llama_stack/apis/resource.py | 2 +- llama_stack/apis/shields/client.py | 4 ++-- llama_stack/apis/shields/shields.py | 2 +- llama_stack/distribution/routers/routers.py | 4 ++-- .../distribution/routers/routing_tables.py | 19 +++++-------------- llama_stack/providers/datatypes.py | 4 +--- .../inline/safety/meta_reference/safety.py | 3 --- .../remote/safety/bedrock/bedrock.py | 9 +++------ .../providers/tests/safety/test_safety.py | 2 +- 9 files changed, 16 insertions(+), 33 deletions(-) diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 1abf66301..c386311cc 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -26,7 +26,7 @@ class Resource(BaseModel): description="Unique identifier for this resource in llama stack" ) - provider_resource_identifier: str = Field( + provider_resource_id: str = Field( description="Unique identifier for this resource in the provider", default=None, ) diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py index 02aa7c2a4..2f6b5e649 100644 --- a/llama_stack/apis/shields/client.py +++ b/llama_stack/apis/shields/client.py @@ -38,7 +38,7 @@ class ShieldsClient(Shields): self, shield_id: str, shield_type: ShieldType, - provider_resource_identifier: Optional[str], + provider_shield_id: Optional[str], provider_id: Optional[str], params: Optional[Dict[str, Any]], ) -> None: @@ -48,7 +48,7 @@ class ShieldsClient(Shields): json={ "shield_id": shield_id, "shield_type": shield_type, - "provider_resource_identifier": provider_resource_identifier, + "provider_shield_id": provider_shield_id, "provider_id": provider_id, "params": params, }, diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 3f0da8573..42fe717fa 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -42,7 +42,7 @@ class Shields(Protocol): self, shield_id: str, shield_type: ShieldType, - provider_resource_identifier: Optional[str] = None, + provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 1643091e8..01861b9b3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): self, shield_id: str, shield_type: ShieldType, - provider_resource_identifier: Optional[str] = None, + provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: return await self.routing_table.register_shield( - shield_id, shield_type, provider_resource_identifier, provider_id, params + shield_id, shield_type, provider_shield_id, provider_id, params ) async def run_shield( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 5378661fd..e02c1cef6 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -219,25 +219,16 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): self, shield_id: str, shield_type: ShieldType, - provider_resource_identifier: Optional[str] = None, + provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: - if provider_resource_identifier is None: - provider_resource_identifier = shield_id + if provider_shield_id is None: + provider_shield_id = shield_id if provider_id is None: # If provider_id not specified, use the only provider if it supports this shield type if len(self.impls_by_provider_id) == 1: - provider = list(self.impls_by_provider_id.values())[0] - if ( - hasattr(provider, "supported_shield_types") - and shield_type in await provider.supported_shield_types() - ): - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider available that supports shield type {shield_type}" - ) + provider_id = list(self.impls_by_provider_id.keys())[0] else: raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." @@ -247,7 +238,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): shield = Shield( identifier=shield_id, shield_type=shield_type, - provider_resource_identifier=provider_resource_identifier, + provider_resource_id=provider_shield_id, provider_id=provider_id, params=params, ) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 68543b3ce..29c551382 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks import MemoryBankDef from llama_stack.apis.models import ModelDef from llama_stack.apis.scoring_functions import ScoringFnDef -from llama_stack.apis.shields import Shield, ShieldType +from llama_stack.apis.shields import Shield @json_schema_type @@ -51,8 +51,6 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... - async def supported_shield_types(self) -> List[ShieldType]: ... - class MemoryBanksProtocolPrivate(Protocol): async def list_memory_banks(self) -> List[MemoryBankDef]: ... diff --git a/llama_stack/providers/inline/safety/meta_reference/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py index 787150e22..824a7cd7e 100644 --- a/llama_stack/providers/inline/safety/meta_reference/safety.py +++ b/llama_stack/providers/inline/safety/meta_reference/safety.py @@ -47,9 +47,6 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.shield_type not in self.available_shields: raise ValueError(f"Shield type {shield.shield_type} not supported") - async def supported_shield_types(self) -> List[ShieldType]: - return SUPPORTED_SHIELDS - async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index e9955ab66..d49035321 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -42,7 +42,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): async def register_shield(self, shield: Shield) -> None: response = self.bedrock_client.list_guardrails( - guardrailIdentifier=shield.identifier, + guardrailIdentifier=shield.provider_resource_id, ) if ( not response["guardrails"] @@ -50,12 +50,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): or response["guardrails"][0]["version"] != shield.params["guardrailVersion"] ): raise ValueError( - f"Shield {shield.identifier} with version {shield.params['guardrailVersion']} not found in Bedrock" + f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) - async def supported_shield_types(self) -> List[ShieldType]: - return BEDROCK_SUPPORTED_SHIELDS - async def run_shield( self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: @@ -89,7 +86,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): ) response = self.bedrock_runtime_client.apply_guardrail( - guardrailIdentifier=shield.identifier, + guardrailIdentifier=shield.provider_resource_id, guardrailVersion=shield_params["guardrailVersion"], source="OUTPUT", # or 'INPUT' depending on your use case content=content_messages, diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 2dd748a60..48fab9741 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -22,7 +22,7 @@ class TestSafety: async def test_new_shield(self, safety_stack): _, shields_impl, shield = safety_stack assert shield is not None - assert shield.provider_resource_identifier == shield.identifier + assert shield.provider_resource_id == shield.identifier assert shield.provider_id is not None @pytest.mark.asyncio