diff --git a/distributions/meta-reference-gpu/run.yaml b/distributions/meta-reference-gpu/run.yaml index 9bf7655f9..ad3187aa1 100644 --- a/distributions/meta-reference-gpu/run.yaml +++ b/distributions/meta-reference-gpu/run.yaml @@ -13,14 +13,22 @@ apis: - safety providers: inference: - - provider_id: meta0 + - provider_id: meta-reference-inference provider_type: meta-reference config: - model: Llama3.1-8B-Instruct + model: Llama3.2-3B-Instruct quantization: null torch_seed: null max_seq_len: 4096 max_batch_size: 1 + - provider_id: meta-reference-safety + provider_type: meta-reference + config: + model: Llama-Guard-3-1B + quantization: null + torch_seed: null + max_seq_len: 2048 + max_batch_size: 1 safety: - provider_id: meta0 provider_type: meta-reference @@ -28,10 +36,9 @@ providers: llama_guard_shield: model: Llama-Guard-3-1B excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M +# Uncomment to use prompt guard +# prompt_guard_shield: +# model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference @@ -52,7 +59,7 @@ providers: persistence_store: namespace: null type: sqlite - db_path: ~/.llama/runtime/kvstore.db + db_path: ~/.llama/runtime/agents_store.db telemetry: - provider_id: meta0 provider_type: meta-reference diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 0d1177f5a..7c8e3939a 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -23,7 +23,7 @@ class ShieldDef(BaseModel): identifier: str = Field( description="A unique identifier for the shield type", ) - type: str = Field( + shield_type: str = Field( description="The type of shield this is; the value is one of the ShieldType enum" ) params: Dict[str, Any] = Field( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index fcf3451c1..c184557c6 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -178,13 +178,13 @@ class CommonRoutingTableImpl(RoutingTable): await register_object_with_provider(obj, p) await self.dist_registry.register(obj) + async def get_all(self) -> List[RoutableObjectWithProvider]: + return await self.dist_registry.get_all() + class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: return self.get_object_by_identifier(identifier) @@ -195,10 +195,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: return self.get_object_by_identifier(shield_type) @@ -209,10 +206,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_memory_bank( self, identifier: str @@ -227,10 +221,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[DatasetDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_dataset( self, dataset_identifier: str @@ -243,10 +234,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_scoring_function( self, name: str diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index c7e9630eb..da45ed5b8 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -37,7 +37,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat return [ ShieldDef( identifier=ShieldType.llama_guard.value, - type=ShieldType.llama_guard.value, + shield_type=ShieldType.llama_guard.value, params={}, ) ] diff --git a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py index 37ea96270..fc6efd71b 100644 --- a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py +++ b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py @@ -25,8 +25,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: ShieldDef) -> None: - if shield.type != ShieldType.code_scanner.value: - raise ValueError(f"Unsupported safety shield type: {shield.type}") + if shield.shield_type != ShieldType.code_scanner.value: + raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") async def run_shield( self, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index de438ad29..28c78b65c 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -49,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): return [ ShieldDef( identifier=shield_type, - type=shield_type, + shield_type=shield_type, params={}, ) for shield_type in self.available_shields @@ -92,14 +92,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): return RunShieldResponse(violation=violation) def get_shield_impl(self, shield: ShieldDef) -> ShieldBase: - if shield.type == ShieldType.llama_guard.value: + if shield.shield_type == ShieldType.llama_guard.value: cfg = self.config.llama_guard_shield return LlamaGuardShield( model=cfg.model, inference_api=self.inference_api, excluded_categories=cfg.excluded_categories, ) - elif shield.type == ShieldType.prompt_guard.value: + elif shield.shield_type == ShieldType.prompt_guard.value: model_dir = model_local_dir(PROMPT_GUARD_MODEL) subtype = shield.params.get("prompt_guard_type", "injection") if subtype == "injection": @@ -109,4 +109,4 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): else: raise ValueError(f"Unknown prompt guard type: {subtype}") else: - raise ValueError(f"Unknown shield type: {shield.type}") + raise ValueError(f"Unknown shield type: {shield.shield_type}")