From 983d6ce2dfdde3f0d359cb2cb0a60ff4e3f7e32e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 12:37:24 -0800 Subject: [PATCH] Remove the "ShieldType" concept (#430) # What does this PR do? This PR kills the notion of "ShieldType". The impetus for this is the realization: > Why is keyword llama-guard appearing so many times everywhere, sometimes with hyphens, sometimes with underscores? Now that we have a notion of "provider specific resource identifiers" and "user specific aliases" for those and the fact that this works with models ("Llama3.1-8B-Instruct" <> "fireworks/llama-v3p1-..."), we can follow the same rules for Shields. So each Safety provider can make up a notion of identifiers it has registered. This already happens with Bedrock correctly. We just generalize it for Llama Guard, Prompt Guard, etc. For Llama Guard, we further simplify by just adopting the underlying model name itself as the identifier! No confusion necessary. While doing this, I noticed a bug in our DistributionRegistry where we weren't scoping identifiers by type. Fixed. ## Feature/Issue validation/testing/test plan Ran (inference, safety, memory, agents) tests with ollama and fireworks providers. 
--- .github/PULL_REQUEST_TEMPLATE.md | 3 +- docs/_deprecating_soon.ipynb | 4 +- docs/resources/llama-stack-spec.html | 68 +++++++------------ docs/resources/llama-stack-spec.yaml | 42 ++++-------- docs/zero_to_hero_guide/06_Safety101.ipynb | 6 +- llama_stack/apis/datasets/datasets.py | 1 - llama_stack/apis/eval_tasks/eval_tasks.py | 1 - llama_stack/apis/memory_banks/memory_banks.py | 1 - llama_stack/apis/models/models.py | 1 - llama_stack/apis/safety/client.py | 4 +- .../scoring_functions/scoring_functions.py | 1 - llama_stack/apis/shields/client.py | 6 +- llama_stack/apis/shields/shields.py | 11 --- llama_stack/distribution/routers/routers.py | 3 +- .../distribution/routers/routing_tables.py | 31 ++++----- llama_stack/distribution/stack.py | 1 - llama_stack/distribution/store/registry.py | 66 ++++++++++-------- .../safety/code_scanner/code_scanner.py | 12 +++- .../inline/safety/llama_guard/config.py | 26 +------ .../inline/safety/llama_guard/llama_guard.py | 25 ++++--- .../safety/prompt_guard/prompt_guard.py | 6 +- .../remote/safety/bedrock/bedrock.py | 5 -- .../providers/tests/agents/fixtures.py | 2 +- .../providers/tests/agents/test_agents.py | 8 ++- .../providers/tests/safety/fixtures.py | 24 +++---- .../providers/tests/safety/test_safety.py | 1 - 26 files changed, 150 insertions(+), 209 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 79701d926..fb02dd136 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant - [ ] Addresses issue (#issue) -## Feature/Issue validation/testing/test plan + +## Test Plan Please describe: - tests you ran to verify your changes with result summaries. 
diff --git a/docs/_deprecating_soon.ipynb b/docs/_deprecating_soon.ipynb index 343005962..7fa4034ce 100644 --- a/docs/_deprecating_soon.ipynb +++ b/docs/_deprecating_soon.ipynb @@ -180,8 +180,8 @@ " tools=tools,\n", " tool_choice=\"auto\",\n", " tool_prompt_format=\"json\",\n", - " input_shields=[\"llama_guard\"],\n", - " output_shields=[\"llama_guard\"],\n", + " input_shields=[\"Llama-Guard-3-1B\"],\n", + " output_shields=[\"Llama-Guard-3-1B\"],\n", " enable_session_persistence=True,\n", " )\n", "\n", diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 231633464..7ef4ece21 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. 
The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" }, "servers": [ { @@ -5743,9 +5743,6 @@ "const": "shield", "default": "shield" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "params": { "type": "object", "additionalProperties": { @@ -5777,20 +5774,10 @@ "identifier", "provider_resource_id", "provider_id", - "type", - "shield_type" + "type" ], "title": "A safety shield resource that can be used to check content" }, - "ShieldType": { - "type": "string", - "enum": [ - "generic_content_shield", - "llama_guard", - "code_scanner", - "prompt_guard" - ] - }, "Trace": { "type": "object", "properties": { @@ -7262,9 +7249,6 @@ "shield_id": { "type": "string" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "provider_shield_id": { "type": "string" }, @@ -7299,8 +7283,7 @@ }, "additionalProperties": false, "required": [ - "shield_id", - "shield_type" + "shield_id" ] }, "RunEvalRequest": { @@ -7854,13 +7837,19 @@ ], "tags": [ { - "name": "Inference" + "name": "MemoryBanks" + }, + { + "name": "BatchInference" }, { "name": "Agents" }, { - "name": "Telemetry" + "name": "Inference" + }, + { + "name": "DatasetIO" }, { "name": "Eval" @@ -7869,43 +7858,37 @@ "name": "Models" }, { - "name": "Inspect" - }, - { - "name": "EvalTasks" + "name": "PostTraining" }, { "name": "ScoringFunctions" }, { - "name": "Memory" - }, - { - "name": "Safety" - }, - { - "name": "DatasetIO" - }, - { - "name": "MemoryBanks" + "name": "Datasets" }, { "name": "Shields" }, { - "name": "PostTraining" + "name": "Telemetry" }, { - "name": "Datasets" + "name": "Inspect" }, { - "name": "Scoring" + "name": "Safety" }, { "name": "SyntheticDataGeneration" }, { - "name": "BatchInference" + "name": "Memory" + }, + { + "name": "Scoring" + }, + { + "name": "EvalTasks" }, { "name": "BuiltinTool", @@ -8255,10 +8238,6 @@ "name": "Shield", "description": "A safety shield resource that can be used to check content\n\n" }, - { - 
"name": "ShieldType", - "description": "" - }, { "name": "Trace", "description": "" @@ -8614,7 +8593,6 @@ "Session", "Shield", "ShieldCallStep", - "ShieldType", "SpanEndPayload", "SpanStartPayload", "SpanStatus", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 4e02e8075..b86c0df61 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -2227,11 +2227,8 @@ components: type: string shield_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' required: - shield_id - - shield_type type: object RestAPIExecutionConfig: additionalProperties: false @@ -2698,8 +2695,6 @@ components: type: string provider_resource_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' type: const: shield default: shield @@ -2709,7 +2704,6 @@ components: - provider_resource_id - provider_id - type - - shield_type title: A safety shield resource that can be used to check content type: object ShieldCallStep: @@ -2736,13 +2730,6 @@ components: - step_id - step_type type: object - ShieldType: - enum: - - generic_content_shield - - llama_guard - - code_scanner - - prompt_guard - type: string SpanEndPayload: additionalProperties: false properties: @@ -3397,7 +3384,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4761,24 +4748,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inference +- name: MemoryBanks +- name: BatchInference - name: Agents -- name: Telemetry +- name: Inference +- name: DatasetIO - name: Eval - name: Models -- name: Inspect -- name: EvalTasks -- name: ScoringFunctions -- name: Memory -- name: Safety -- name: DatasetIO -- name: MemoryBanks -- name: Shields - name: PostTraining +- name: ScoringFunctions - name: Datasets -- name: Scoring +- name: Shields +- name: Telemetry +- name: Inspect +- name: Safety - name: SyntheticDataGeneration -- name: BatchInference +- name: Memory +- name: Scoring +- name: EvalTasks - description: name: BuiltinTool - description: ' name: Shield -- description: - name: ShieldType - description: name: Trace - description: 'Checkpoint created during training runs @@ -5343,7 +5328,6 @@ x-tagGroups: - Session - Shield - ShieldCallStep - - ShieldType - SpanEndPayload - SpanStartPayload - SpanStatus diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index e1e9301d3..f5352627e 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -182,13 +182,13 @@ " pass\n", "\n", " async def run_shield(\n", - " self, shield_type: str, messages: List[dict]\n", + " self, shield_id: str, messages: List[dict]\n", " ) -> RunShieldResponse:\n", " async with httpx.AsyncClient() as client:\n", " response = await client.post(\n", " f\"{self.base_url}/safety/run_shield\",\n", " json=dict(\n", - " shield_type=shield_type,\n", + " shield_id=shield_id,\n", " messages=[encodable_dict(m) for m in messages],\n", " ),\n", " 
headers={\n", @@ -216,7 +216,7 @@ " ]:\n", " cprint(f\"User>{message['content']}\", \"green\")\n", " response = await client.run_shield(\n", - " shield_type=\"llama_guard\",\n", + " shield_id=\"Llama-Guard-3-1B\",\n", " messages=[message],\n", " )\n", " print(response)\n", diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index f0f02b3c5..2dc74e6ec 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource): return self.provider_resource_id -@json_schema_type class DatasetInput(CommonDatasetFields, BaseModel): dataset_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py index 10c35c3ee..940dafc06 100644 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource): return self.provider_resource_id -@json_schema_type class EvalTaskInput(CommonEvalTaskFields, BaseModel): eval_task_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 83b292612..c0a0c643a 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -122,7 +122,6 @@ MemoryBank = Annotated[ ] -@json_schema_type class MemoryBankInput(BaseModel): memory_bank_id: str params: BankParams diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index a5d226886..2cd12b4bc 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource): return self.provider_resource_id -@json_schema_type class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/safety/client.py 
b/llama_stack/apis/safety/client.py index 96168fedd..d7d4bc981 100644 --- a/llama_stack/apis/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety: def encodable_dict(d: BaseModel): - return json.loads(d.json()) + return json.loads(d.model_dump_json()) class SafetyClient(Safety): @@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None): ) cprint(f"User>{message.content}", "green") response = await client.run_shield( - shield_id="llama_guard", + shield_id="Llama-Guard-3-1B", messages=[message], ) print(response) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 7a2a83c72..251a683c1 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource): return self.provider_resource_id -@json_schema_type class ScoringFnInput(CommonScoringFnFields, BaseModel): scoring_fn_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py index 2f6b5e649..7556d2d12 100644 --- a/llama_stack/apis/shields/client.py +++ b/llama_stack/apis/shields/client.py @@ -37,7 +37,6 @@ class ShieldsClient(Shields): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str], provider_id: Optional[str], params: Optional[Dict[str, Any]], @@ -47,7 +46,6 @@ class ShieldsClient(Shields): f"{self.base_url}/shields/register", json={ "shield_id": shield_id, - "shield_type": shield_type, "provider_shield_id": provider_shield_id, "provider_id": provider_id, "params": params, @@ -56,12 +54,12 @@ class ShieldsClient(Shields): ) response.raise_for_status() - async def get_shield(self, shield_type: str) -> Optional[Shield]: + async def get_shield(self, shield_id: 
str) -> Optional[Shield]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/shields/get", params={ - "shield_type": shield_type, + "shield_id": shield_id, }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 1dcfd4f4c..5ee444f68 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -13,16 +12,7 @@ from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class ShieldType(Enum): - generic_content_shield = "generic_content_shield" - llama_guard = "llama_guard" - code_scanner = "code_scanner" - prompt_guard = "prompt_guard" - - class CommonShieldFields(BaseModel): - shield_type: ShieldType params: Optional[Dict[str, Any]] = None @@ -59,7 +49,6 @@ class Shields(Protocol): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5f6395e0d..220dfdb56 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -172,13 +172,12 @@ class SafetyRouter(Safety): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: return await self.routing_table.register_shield( - shield_id, shield_type, provider_shield_id, 
provider_id, params + shield_id, provider_shield_id, provider_id, params ) async def run_shield( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 7b369df2c..d6fb5d662 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable): else: raise ValueError("Unknown routing table type") + apiname, objtype = apiname_object() + # Get objects from disk registry - objects = self.dist_registry.get_cached(routing_key) + objects = self.dist_registry.get_cached(objtype, routing_key) if not objects: - apiname, objname = apiname_object() provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" else: provider_ids_str = f"provider: `{provider_ids[0]}`" raise ValueError( - f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}." + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
) for obj in objects: @@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") async def get_object_by_identifier( - self, identifier: str + self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(identifier) + objects = await self.dist_registry.get(type, identifier) if not objects: return None - # kind of ill-defined behavior here, but we'll just return the first one + assert len(objects) == 1 return objects[0] async def register_object(self, obj: RoutableObjectWithProvider): # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.identifier) + existing_objects = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration for existing_obj in existing_objects: @@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[Model]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("model", identifier) async def register_model( self, @@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return await self.get_all_with_type(ResourceType.shield.value) async def get_shield(self, identifier: str) -> Optional[Shield]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("shield", identifier) async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, @@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): params = {} shield = Shield( identifier=shield_id, - shield_type=shield_type, provider_resource_id=provider_shield_id, provider_id=provider_id, params=params, @@ 
-274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return await self.get_all_with_type(ResourceType.memory_bank.value) async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: - return await self.get_object_by_identifier(memory_bank_id) + return await self.get_object_by_identifier("memory_bank", memory_bank_id) async def register_memory_bank( self, @@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): return await self.get_all_with_type("dataset") async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: - return await self.get_object_by_identifier(dataset_id) + return await self.get_object_by_identifier("dataset", dataset_id) async def register_dataset( self, @@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: - return await self.get_all_with_type(ResourceType.scoring_function.value) + return await self.get_all_with_type("scoring_function") async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: - return await self.get_object_by_identifier(scoring_fn_id) + return await self.get_object_by_identifier("scoring_function", scoring_fn_id) async def register_scoring_function( self, @@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): return await self.get_all_with_type("eval_task") async def get_eval_task(self, name: str) -> Optional[EvalTask]: - return await self.get_object_by_identifier(name) + return await self.get_object_by_identifier("eval_task", name) async def register_eval_task( self, diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 3afd51304..1c7325eee 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
from typing import Any, Dict -from termcolor import colored from termcolor import colored diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 6115ea1b3..d837c4375 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Dict, List, Optional, Protocol +from typing import Dict, List, Optional, Protocol, Tuple import pydantic @@ -35,7 +35,8 @@ class DistributionRegistry(Protocol): async def register(self, obj: RoutableObjectWithProvider) -> bool: ... -KEY_FORMAT = "distributions:registry:v1::{}" +KEY_VERSION = "v1" +KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}" class DiskDistributionRegistry(DistributionRegistry): @@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: # Disk registry does not have a cache return [] async def get_all(self) -> List[RoutableObjectWithProvider]: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) - return [await self.get(key.split(":")[-1]) for key in keys] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) + tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys] + return [await self.get(type, identifier) for type, identifier in tuples] + + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + json_str = await self.kvstore.get( + KEY_FORMAT.format(type=type, 
identifier=identifier) + ) if not json_str: return [] @@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry): ] async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.identifier) + existing_objects = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists for eobj in existing_objects: if eobj.provider_id == obj.provider_id: @@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry): obj.model_dump_json() for obj in existing_objects ] # Fixed variable name await self.kvstore.set( - KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + json.dumps(objects_json), ) return True @@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} async def initialize(self) -> None: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) for key in keys: - identifier = key.split(":")[-1] - objects = await super().get(identifier) + type, identifier = key.split(":")[-2:] + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[type, identifier] = objects - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: - return self.cache.get(identifier, []) + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), []) async def get_all(self) -> List[RoutableObjectWithProvider]: return [item for 
sublist in self.cache.values() for item in sublist] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - if identifier in self.cache: - return self.cache[identifier] + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + cachekey = (type, identifier) + if cachekey in self.cache: + return self.cache[cachekey] - objects = await super().get(identifier) + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[cachekey] = objects return objects @@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): if success: # Then update cache - if obj.identifier not in self.cache: - self.cache[obj.identifier] = [] + cachekey = (obj.type, obj.identifier) + if cachekey not in self.cache: + self.cache[cachekey] = [] # Check if provider already exists in cache - for cached_obj in self.cache[obj.identifier]: + for cached_obj in self.cache[cachekey]: if cached_obj.provider_id == obj.provider_id: return success # If not, update cache - self.cache[obj.identifier].append(obj) + self.cache[cachekey].append(obj) return success diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 1ca65c9bb..c477c685c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -14,6 +14,12 @@ from .config import CodeScannerConfig from llama_stack.apis.safety import * # noqa: F403 +ALLOWED_CODE_SCANNER_MODEL_IDS = [ + "CodeScanner", + "CodeShield", +] + + class MetaReferenceCodeScannerSafetyImpl(Safety): def __init__(self, config: CodeScannerConfig, deps) -> None: self.config = config @@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.code_scanner: - raise 
ValueError(f"Unsupported safety shield type: {shield.shield_type}") + if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS: + raise ValueError( + f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}" + ) async def run_shield( self, diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py index aec856bce..72036fd1c 100644 --- a/llama_stack/providers/inline/safety/llama_guard/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -6,32 +6,8 @@ from typing import List -from llama_models.sku_list import CoreModelId, safety_models - -from pydantic import BaseModel, field_validator +from pydantic import BaseModel class LlamaGuardConfig(BaseModel): - model: str = "Llama-Guard-3-1B" excluded_categories: List[str] = [] - - @field_validator("model") - @classmethod - def validate_model(cls, model: str) -> str: - permitted_models = [ - m.descriptor() - for m in safety_models() - if ( - m.core_model_id - in { - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_3_1b, - CoreModelId.llama_guard_3_11b_vision, - } - ) - ] - if model not in permitted_models: - raise ValueError( - f"Invalid model: {model}. 
Must be one of {permitted_models}" - ) - return model diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 12d012b16..494c1b43e 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_ELECTIONS, ] +LLAMA_GUARD_MODEL_IDS = [ + CoreModelId.llama_guard_3_8b.value, + CoreModelId.llama_guard_3_1b.value, + CoreModelId.llama_guard_3_11b_vision.value, +] MODEL_TO_SAFETY_CATEGORIES_MAP = { CoreModelId.llama_guard_3_8b.value: ( @@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): self.inference_api = deps[Api.inference] async def initialize(self) -> None: - self.shield = LlamaGuardShield( - model=self.config.model, - inference_api=self.inference_api, - excluded_categories=self.config.excluded_categories, - ) + pass async def shutdown(self) -> None: pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.llama_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS: + raise ValueError( + f"Unsupported Llama Guard type: {shield.provider_resource_id}. 
Allowed types: {LLAMA_GUARD_MODEL_IDS}" + ) async def run_shield( self, @@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if len(messages) > 0 and messages[0].role != Role.user.value: messages[0] = UserMessage(content=messages[0].content) - return await self.shield.run(messages) + impl = LlamaGuardShield( + model=shield.provider_resource_id, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + return await impl.run(messages) class LlamaGuardShield: diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 20bfdd241..9f3d78374 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.prompt_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id != PROMPT_GUARD_MODEL: + raise ValueError( + f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. 
" + ) async def run_shield( self, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index d49035321..78e8105e0 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig logger = logging.getLogger(__name__) -BEDROCK_SUPPORTED_SHIELDS = [ - ShieldType.generic_content_shield, -] - - class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 64f493b88..db157174f 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -44,7 +44,7 @@ def agents_meta_reference() -> ProviderFixture: providers=[ Provider( provider_id="meta-reference", - provider_type="meta-reference", + provider_type="inline::meta-reference", config=MetaReferenceAgentsImplConfig( # TODO: make this an in-memory store persistence_store=SqliteKVStoreConfig( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index b3f3dc31c..47e5a751f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config): class TestAgents: @pytest.mark.asyncio - async def test_agent_turns_with_safety(self, agents_stack, common_params): + async def test_agent_turns_with_safety( + self, safety_model, agents_stack, common_params + ): agents_impl, _ = agents_stack agent_id, session_id = await create_agent_session( agents_impl, AgentConfig( **{ **common_params, - "input_shields": ["llama_guard"], - "output_shields": ["llama_guard"], + "input_shields": [safety_model], + "output_shields": 
[safety_model], } ), ) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 66576e9d7..b73c2d798 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput -from llama_stack.apis.shields import ShieldInput, ShieldType +from llama_stack.apis.shields import ShieldInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture: Provider( provider_id="inline::llama-guard", provider_type="inline::llama-guard", - config=LlamaGuardConfig(model=safety_model).model_dump(), + config=LlamaGuardConfig().model_dump(), ) ], ) @@ -101,6 +101,8 @@ async def safety_stack(inference_model, safety_model, request): shield_provider_type = safety_fixture.providers[0].provider_type shield_input = get_shield_to_register(shield_provider_type, safety_model) + print(f"inference_model: {inference_model}") + print(f"shield_input = {shield_input}") impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, @@ -114,20 +116,14 @@ async def safety_stack(inference_model, safety_model, request): def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: - shield_config = {} - shield_type = ShieldType.llama_guard - identifier = "llama_guard" - if provider_type == "meta-reference": - shield_config["model"] = safety_model - elif provider_type == "remote::together": - shield_config["model"] = safety_model - elif provider_type == "remote::bedrock": + if provider_type == "remote::bedrock": identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") - shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") - shield_type = ShieldType.generic_content_shield + params = 
{"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} + else: + params = {} + identifier = safety_model return ShieldInput( shield_id=identifier, - shield_type=shield_type, - params=shield_config, + params=params, ) diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 48fab9741..9daa7bf40 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -34,7 +34,6 @@ class TestSafety: for shield in response: assert isinstance(shield, Shield) - assert shield.shield_type in [v for v in ShieldType] @pytest.mark.asyncio async def test_run_shield(self, safety_stack):