diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 79701d926..fb02dd136 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant - [ ] Addresses issue (#issue) -## Feature/Issue validation/testing/test plan + +## Test Plan Please describe: - tests you ran to verify your changes with result summaries. diff --git a/docs/_deprecating_soon.ipynb b/docs/_deprecating_soon.ipynb index 343005962..7fa4034ce 100644 --- a/docs/_deprecating_soon.ipynb +++ b/docs/_deprecating_soon.ipynb @@ -180,8 +180,8 @@ " tools=tools,\n", " tool_choice=\"auto\",\n", " tool_prompt_format=\"json\",\n", - " input_shields=[\"llama_guard\"],\n", - " output_shields=[\"llama_guard\"],\n", + " input_shields=[\"Llama-Guard-3-1B\"],\n", + " output_shields=[\"Llama-Guard-3-1B\"],\n", " enable_session_persistence=True,\n", " )\n", "\n", diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 231633464..7ef4ece21 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" }, "servers": [ { @@ -5743,9 +5743,6 @@ "const": "shield", "default": "shield" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "params": { "type": "object", "additionalProperties": { @@ -5777,20 +5774,10 @@ "identifier", "provider_resource_id", "provider_id", - "type", - "shield_type" + "type" ], "title": "A safety shield resource that can be used to check content" }, - "ShieldType": { - "type": "string", - "enum": [ - "generic_content_shield", - "llama_guard", - "code_scanner", - "prompt_guard" - ] - }, "Trace": { "type": "object", "properties": { @@ -7262,9 +7249,6 @@ "shield_id": { "type": "string" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "provider_shield_id": { "type": "string" }, @@ -7299,8 +7283,7 @@ }, "additionalProperties": false, "required": [ - "shield_id", - "shield_type" + "shield_id" ] }, "RunEvalRequest": { @@ -7854,13 +7837,19 @@ ], "tags": [ { - "name": "Inference" + "name": "MemoryBanks" + }, + { + "name": "BatchInference" }, { "name": "Agents" }, { - "name": "Telemetry" + "name": "Inference" + }, + { + "name": "DatasetIO" }, { "name": "Eval" @@ -7869,43 +7858,37 @@ "name": "Models" }, { - "name": "Inspect" - }, - { - "name": "EvalTasks" + "name": "PostTraining" }, { "name": "ScoringFunctions" }, { - "name": "Memory" - }, - { - "name": "Safety" - }, - { - "name": "DatasetIO" - }, - { - "name": "MemoryBanks" + "name": "Datasets" }, { "name": "Shields" }, { - "name": "PostTraining" + "name": "Telemetry" }, { - "name": "Datasets" + "name": "Inspect" }, { - "name": "Scoring" + "name": "Safety" }, { "name": "SyntheticDataGeneration" }, { - "name": "BatchInference" + "name": "Memory" + }, + { + "name": "Scoring" + }, + { + "name": "EvalTasks" }, { "name": "BuiltinTool", @@ -8255,10 +8238,6 @@ "name": "Shield", "description": "A safety shield resource that can be used to check content\n\n" }, - { - "name": "ShieldType", - "description": "" - }, { "name": "Trace", "description": "" @@ -8614,7 +8593,6 @@ "Session", "Shield", "ShieldCallStep", - "ShieldType", "SpanEndPayload", "SpanStartPayload", "SpanStatus", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 4e02e8075..b86c0df61 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -2227,11 +2227,8 @@ components: type: string shield_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' required: - shield_id - - shield_type type: object RestAPIExecutionConfig: additionalProperties: false @@ -2698,8 +2695,6 @@ components: type: string provider_resource_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' type: const: shield default: shield @@ -2709,7 +2704,6 @@ components: - provider_resource_id - provider_id - type - - shield_type title: A safety shield resource that can be used to check content type: object ShieldCallStep: @@ -2736,13 +2730,6 @@ components: - step_id - step_type type: object - ShieldType: - enum: - - generic_content_shield - - llama_guard - - code_scanner - - prompt_guard - type: string SpanEndPayload: additionalProperties: false properties: @@ -3397,7 +3384,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4761,24 +4748,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inference +- name: MemoryBanks +- name: BatchInference - name: Agents -- name: Telemetry +- name: Inference +- name: DatasetIO - name: Eval - name: Models -- name: Inspect -- name: EvalTasks -- name: ScoringFunctions -- name: Memory -- name: Safety -- name: DatasetIO -- name: MemoryBanks -- name: Shields - name: PostTraining +- name: ScoringFunctions - name: Datasets -- name: Scoring +- name: Shields +- name: Telemetry +- name: Inspect +- name: Safety - name: SyntheticDataGeneration -- name: BatchInference +- name: Memory +- name: Scoring +- name: EvalTasks - description: name: BuiltinTool - description: ' name: Shield -- description: - name: ShieldType - description: name: Trace - description: 'Checkpoint created during training runs @@ -5343,7 +5328,6 @@ x-tagGroups: - Session - Shield - ShieldCallStep - - ShieldType - SpanEndPayload - SpanStartPayload - SpanStatus diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index e1e9301d3..f5352627e 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -182,13 +182,13 @@ " pass\n", "\n", " async def run_shield(\n", - " self, shield_type: str, messages: List[dict]\n", + " self, shield_id: str, messages: List[dict]\n", " ) -> RunShieldResponse:\n", " async with httpx.AsyncClient() as client:\n", " response = await client.post(\n", " f\"{self.base_url}/safety/run_shield\",\n", " json=dict(\n", - " shield_type=shield_type,\n", + " shield_id=shield_id,\n", " messages=[encodable_dict(m) for m in messages],\n", " ),\n", " headers={\n", @@ -216,7 +216,7 @@ " ]:\n", " cprint(f\"User>{message['content']}\", \"green\")\n", " response = await client.run_shield(\n", - " shield_type=\"llama_guard\",\n", + " shield_id=\"Llama-Guard-3-1B\",\n", " messages=[message],\n", " )\n", " print(response)\n", diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index f0f02b3c5..2dc74e6ec 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource): return self.provider_resource_id -@json_schema_type class DatasetInput(CommonDatasetFields, BaseModel): dataset_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py index 10c35c3ee..940dafc06 100644 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource): return self.provider_resource_id -@json_schema_type class EvalTaskInput(CommonEvalTaskFields, BaseModel): eval_task_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 83b292612..c0a0c643a 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -122,7 +122,6 @@ MemoryBank = Annotated[ ] -@json_schema_type class MemoryBankInput(BaseModel): memory_bank_id: str params: BankParams diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index a5d226886..2cd12b4bc 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource): return self.provider_resource_id -@json_schema_type class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py index 96168fedd..d7d4bc981 100644 --- a/llama_stack/apis/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety: def encodable_dict(d: BaseModel): - return json.loads(d.json()) + return json.loads(d.model_dump_json()) class SafetyClient(Safety): @@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None): ) cprint(f"User>{message.content}", "green") response = await client.run_shield( - shield_id="llama_guard", + shield_id="Llama-Guard-3-1B", messages=[message], ) print(response) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 7a2a83c72..251a683c1 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource): return self.provider_resource_id -@json_schema_type class ScoringFnInput(CommonScoringFnFields, BaseModel): scoring_fn_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py index 2f6b5e649..7556d2d12 100644 --- a/llama_stack/apis/shields/client.py +++ b/llama_stack/apis/shields/client.py @@ -37,7 +37,6 @@ class ShieldsClient(Shields): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str], provider_id: Optional[str], params: Optional[Dict[str, Any]], @@ -47,7 +46,6 @@ class ShieldsClient(Shields): f"{self.base_url}/shields/register", json={ "shield_id": shield_id, - "shield_type": shield_type, "provider_shield_id": provider_shield_id, "provider_id": provider_id, "params": params, @@ -56,12 +54,12 @@ class ShieldsClient(Shields): ) response.raise_for_status() - async def get_shield(self, shield_type: str) -> Optional[Shield]: + async def get_shield(self, shield_id: str) -> Optional[Shield]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/shields/get", params={ - "shield_type": shield_type, + "shield_id": shield_id, }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 1dcfd4f4c..5ee444f68 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -13,16 +12,7 @@ from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class ShieldType(Enum): - generic_content_shield = "generic_content_shield" - llama_guard = "llama_guard" - code_scanner = "code_scanner" - prompt_guard = "prompt_guard" - - class CommonShieldFields(BaseModel): - shield_type: ShieldType params: Optional[Dict[str, Any]] = None @@ -59,7 +49,6 @@ class Shields(Protocol): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5f6395e0d..220dfdb56 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -172,13 +172,12 @@ class SafetyRouter(Safety): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: return await self.routing_table.register_shield( - shield_id, shield_type, provider_shield_id, provider_id, params + shield_id, provider_shield_id, provider_id, params ) async def run_shield( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 7b369df2c..d6fb5d662 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable): else: raise ValueError("Unknown routing table type") + apiname, objtype = apiname_object() + # Get objects from disk registry - objects = self.dist_registry.get_cached(routing_key) + objects = self.dist_registry.get_cached(objtype, routing_key) if not objects: - apiname, objname = apiname_object() provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" else: provider_ids_str = f"provider: `{provider_ids[0]}`" raise ValueError( - f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}." + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." ) for obj in objects: @@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") async def get_object_by_identifier( - self, identifier: str + self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(identifier) + objects = await self.dist_registry.get(type, identifier) if not objects: return None - # kind of ill-defined behavior here, but we'll just return the first one + assert len(objects) == 1 return objects[0] async def register_object(self, obj: RoutableObjectWithProvider): # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.identifier) + existing_objects = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration for existing_obj in existing_objects: @@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[Model]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("model", identifier) async def register_model( self, @@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return await self.get_all_with_type(ResourceType.shield.value) async def get_shield(self, identifier: str) -> Optional[Shield]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("shield", identifier) async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, @@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): params = {} shield = Shield( identifier=shield_id, - shield_type=shield_type, provider_resource_id=provider_shield_id, provider_id=provider_id, params=params, @@ -274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return await self.get_all_with_type(ResourceType.memory_bank.value) async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: - return await self.get_object_by_identifier(memory_bank_id) + return await self.get_object_by_identifier("memory_bank", memory_bank_id) async def register_memory_bank( self, @@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): return await self.get_all_with_type("dataset") async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: - return await self.get_object_by_identifier(dataset_id) + return await self.get_object_by_identifier("dataset", dataset_id) async def register_dataset( self, @@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: - return await self.get_all_with_type(ResourceType.scoring_function.value) + return await self.get_all_with_type("scoring_function") async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: - return await self.get_object_by_identifier(scoring_fn_id) + return await self.get_object_by_identifier("scoring_function", scoring_fn_id) async def register_scoring_function( self, @@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): return await self.get_all_with_type("eval_task") async def get_eval_task(self, name: str) -> Optional[EvalTask]: - return await self.get_object_by_identifier(name) + return await self.get_object_by_identifier("eval_task", name) async def register_eval_task( self, diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 3afd51304..1c7325eee 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from typing import Any, Dict -from termcolor import colored from termcolor import colored diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 6115ea1b3..d837c4375 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Dict, List, Optional, Protocol +from typing import Dict, List, Optional, Protocol, Tuple import pydantic @@ -35,7 +35,8 @@ class DistributionRegistry(Protocol): async def register(self, obj: RoutableObjectWithProvider) -> bool: ... -KEY_FORMAT = "distributions:registry:v1::{}" +KEY_VERSION = "v1" +KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}" class DiskDistributionRegistry(DistributionRegistry): @@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: # Disk registry does not have a cache return [] async def get_all(self) -> List[RoutableObjectWithProvider]: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) - return [await self.get(key.split(":")[-1]) for key in keys] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) + tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys] + return [await self.get(type, identifier) for type, identifier in tuples] + + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + json_str = await self.kvstore.get( + KEY_FORMAT.format(type=type, identifier=identifier) + ) if not json_str: return [] @@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry): ] async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.identifier) + existing_objects = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists for eobj in existing_objects: if eobj.provider_id == obj.provider_id: @@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry): obj.model_dump_json() for obj in existing_objects ] # Fixed variable name await self.kvstore.set( - KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + json.dumps(objects_json), ) return True @@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} async def initialize(self) -> None: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) for key in keys: - identifier = key.split(":")[-1] - objects = await super().get(identifier) + type, identifier = key.split(":")[-2:] + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[type, identifier] = objects - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: - return self.cache.get(identifier, []) + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), []) async def get_all(self) -> List[RoutableObjectWithProvider]: return [item for sublist in self.cache.values() for item in sublist] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - if identifier in self.cache: - return self.cache[identifier] + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + cachekey = (type, identifier) + if cachekey in self.cache: + return self.cache[cachekey] - objects = await super().get(identifier) + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[cachekey] = objects return objects @@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): if success: # Then update cache - if obj.identifier not in self.cache: - self.cache[obj.identifier] = [] + cachekey = (obj.type, obj.identifier) + if cachekey not in self.cache: + self.cache[cachekey] = [] # Check if provider already exists in cache - for cached_obj in self.cache[obj.identifier]: + for cached_obj in self.cache[cachekey]: if cached_obj.provider_id == obj.provider_id: return success # If not, update cache - self.cache[obj.identifier].append(obj) + self.cache[cachekey].append(obj) return success diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 1ca65c9bb..c477c685c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -14,6 +14,12 @@ from .config import CodeScannerConfig from llama_stack.apis.safety import * # noqa: F403 +ALLOWED_CODE_SCANNER_MODEL_IDS = [ + "CodeScanner", + "CodeShield", +] + + class MetaReferenceCodeScannerSafetyImpl(Safety): def __init__(self, config: CodeScannerConfig, deps) -> None: self.config = config @@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.code_scanner: - raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") + if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS: + raise ValueError( + f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}" + ) async def run_shield( self, diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py index aec856bce..72036fd1c 100644 --- a/llama_stack/providers/inline/safety/llama_guard/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -6,32 +6,8 @@ from typing import List -from llama_models.sku_list import CoreModelId, safety_models - -from pydantic import BaseModel, field_validator +from pydantic import BaseModel class LlamaGuardConfig(BaseModel): - model: str = "Llama-Guard-3-1B" excluded_categories: List[str] = [] - - @field_validator("model") - @classmethod - def validate_model(cls, model: str) -> str: - permitted_models = [ - m.descriptor() - for m in safety_models() - if ( - m.core_model_id - in { - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_3_1b, - CoreModelId.llama_guard_3_11b_vision, - } - ) - ] - if model not in permitted_models: - raise ValueError( - f"Invalid model: {model}. Must be one of {permitted_models}" - ) - return model diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 12d012b16..494c1b43e 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_ELECTIONS, ] +LLAMA_GUARD_MODEL_IDS = [ + CoreModelId.llama_guard_3_8b.value, + CoreModelId.llama_guard_3_1b.value, + CoreModelId.llama_guard_3_11b_vision.value, +] MODEL_TO_SAFETY_CATEGORIES_MAP = { CoreModelId.llama_guard_3_8b.value: ( @@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): self.inference_api = deps[Api.inference] async def initialize(self) -> None: - self.shield = LlamaGuardShield( - model=self.config.model, - inference_api=self.inference_api, - excluded_categories=self.config.excluded_categories, - ) + pass async def shutdown(self) -> None: pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.llama_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS: + raise ValueError( + f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}" + ) async def run_shield( self, @@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if len(messages) > 0 and messages[0].role != Role.user.value: messages[0] = UserMessage(content=messages[0].content) - return await self.shield.run(messages) + impl = LlamaGuardShield( + model=shield.provider_resource_id, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + return await impl.run(messages) class LlamaGuardShield: diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 20bfdd241..9f3d78374 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.prompt_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id != PROMPT_GUARD_MODEL: + raise ValueError( + f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. " + ) async def run_shield( self, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index d49035321..78e8105e0 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig logger = logging.getLogger(__name__) -BEDROCK_SUPPORTED_SHIELDS = [ - ShieldType.generic_content_shield, -] - - class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 64f493b88..db157174f 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -44,7 +44,7 @@ def agents_meta_reference() -> ProviderFixture: providers=[ Provider( provider_id="meta-reference", - provider_type="meta-reference", + provider_type="inline::meta-reference", config=MetaReferenceAgentsImplConfig( # TODO: make this an in-memory store persistence_store=SqliteKVStoreConfig( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index b3f3dc31c..47e5a751f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config): class TestAgents: @pytest.mark.asyncio - async def test_agent_turns_with_safety(self, agents_stack, common_params): + async def test_agent_turns_with_safety( + self, safety_model, agents_stack, common_params + ): agents_impl, _ = agents_stack agent_id, session_id = await create_agent_session( agents_impl, AgentConfig( **{ **common_params, - "input_shields": ["llama_guard"], - "output_shields": ["llama_guard"], + "input_shields": [safety_model], + "output_shields": [safety_model], } ), ) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 66576e9d7..b73c2d798 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput -from llama_stack.apis.shields import ShieldInput, ShieldType +from llama_stack.apis.shields import ShieldInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture: Provider( provider_id="inline::llama-guard", provider_type="inline::llama-guard", - config=LlamaGuardConfig(model=safety_model).model_dump(), + config=LlamaGuardConfig().model_dump(), ) ], ) @@ -101,6 +101,8 @@ async def safety_stack(inference_model, safety_model, request): shield_provider_type = safety_fixture.providers[0].provider_type shield_input = get_shield_to_register(shield_provider_type, safety_model) + print(f"inference_model: {inference_model}") + print(f"shield_input = {shield_input}") impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, @@ -114,20 +116,14 @@ async def safety_stack(inference_model, safety_model, request): def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: - shield_config = {} - shield_type = ShieldType.llama_guard - identifier = "llama_guard" - if provider_type == "meta-reference": - shield_config["model"] = safety_model - elif provider_type == "remote::together": - shield_config["model"] = safety_model - elif provider_type == "remote::bedrock": + if provider_type == "remote::bedrock": identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") - shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") - shield_type = ShieldType.generic_content_shield + params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} + else: + params = {} + identifier = safety_model return ShieldInput( shield_id=identifier, - shield_type=shield_type, - params=shield_config, + params=params, ) diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 48fab9741..9daa7bf40 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -34,7 +34,6 @@ class TestSafety: for shield in response: assert isinstance(shield, Shield) - assert shield.shield_type in [v for v in ShieldType] @pytest.mark.asyncio async def test_run_shield(self, safety_stack):