From b1c3a9548539e03d0a41d645011cee586e12b790 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 11:41:23 -0800 Subject: [PATCH] Kill the notion of shield_type --- docs/resources/llama-stack-spec.html | 68 +++++++------------ docs/resources/llama-stack-spec.yaml | 42 ++++-------- docs/zero_to_hero_guide/06_Safety101.ipynb | 6 +- llama_stack/apis/datasets/datasets.py | 1 - llama_stack/apis/eval_tasks/eval_tasks.py | 1 - llama_stack/apis/memory_banks/memory_banks.py | 1 - llama_stack/apis/models/models.py | 1 - .../scoring_functions/scoring_functions.py | 1 - llama_stack/apis/shields/client.py | 6 +- llama_stack/apis/shields/shields.py | 11 --- llama_stack/distribution/routers/routers.py | 3 +- .../distribution/routers/routing_tables.py | 2 - .../safety/code_scanner/code_scanner.py | 12 +++- .../inline/safety/llama_guard/config.py | 26 +------ .../inline/safety/llama_guard/llama_guard.py | 25 ++++--- .../safety/prompt_guard/prompt_guard.py | 6 +- .../remote/safety/bedrock/bedrock.py | 5 -- .../providers/tests/agents/test_agents.py | 8 ++- .../providers/tests/safety/fixtures.py | 22 +++--- .../providers/tests/safety/test_safety.py | 1 - 20 files changed, 87 insertions(+), 161 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 231633464..7ef4ece21 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" }, "servers": [ { @@ -5743,9 +5743,6 @@ "const": "shield", "default": "shield" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "params": { "type": "object", "additionalProperties": { @@ -5777,20 +5774,10 @@ "identifier", "provider_resource_id", "provider_id", - "type", - "shield_type" + "type" ], "title": "A safety shield resource that can be used to check content" }, - "ShieldType": { - "type": "string", - "enum": [ - "generic_content_shield", - "llama_guard", - "code_scanner", - "prompt_guard" - ] - }, "Trace": { "type": "object", "properties": { @@ -7262,9 +7249,6 @@ "shield_id": { "type": "string" }, - "shield_type": { - "$ref": "#/components/schemas/ShieldType" - }, "provider_shield_id": { "type": "string" }, @@ -7299,8 +7283,7 @@ }, "additionalProperties": false, "required": [ - "shield_id", - "shield_type" + "shield_id" ] }, "RunEvalRequest": { @@ -7854,13 +7837,19 @@ ], "tags": [ { - "name": "Inference" + "name": "MemoryBanks" + }, + { + "name": "BatchInference" }, { "name": "Agents" }, { - "name": "Telemetry" + "name": "Inference" + }, + { + "name": "DatasetIO" }, { "name": "Eval" @@ -7869,43 +7858,37 @@ "name": "Models" }, { - "name": "Inspect" - }, - { - "name": "EvalTasks" + "name": "PostTraining" }, { "name": "ScoringFunctions" }, { - "name": "Memory" - }, - { - "name": "Safety" - }, - { - "name": "DatasetIO" - }, - { - "name": "MemoryBanks" + "name": "Datasets" }, { "name": "Shields" }, { - "name": "PostTraining" + "name": "Telemetry" }, { - "name": "Datasets" + "name": "Inspect" }, { - "name": "Scoring" + "name": "Safety" }, { "name": "SyntheticDataGeneration" }, { - "name": "BatchInference" + "name": "Memory" + }, + { + "name": "Scoring" + }, + { + "name": "EvalTasks" }, { "name": "BuiltinTool", @@ -8255,10 +8238,6 @@ "name": "Shield", "description": "A safety shield resource that can be used to check content\n\n" }, - { - "name": "ShieldType", - "description": "" - }, { "name": "Trace", "description": "" @@ -8614,7 +8593,6 @@ "Session", "Shield", "ShieldCallStep", - "ShieldType", "SpanEndPayload", "SpanStartPayload", "SpanStatus", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 4e02e8075..b86c0df61 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -2227,11 +2227,8 @@ components: type: string shield_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' required: - shield_id - - shield_type type: object RestAPIExecutionConfig: additionalProperties: false @@ -2698,8 +2695,6 @@ components: type: string provider_resource_id: type: string - shield_type: - $ref: '#/components/schemas/ShieldType' type: const: shield default: shield @@ -2709,7 +2704,6 @@ components: - provider_resource_id - provider_id - type - - shield_type title: A safety shield resource that can be used to check content type: object ShieldCallStep: @@ -2736,13 +2730,6 @@ components: - step_id - step_type type: object - ShieldType: - enum: - - generic_content_shield - - llama_guard - - code_scanner - - prompt_guard - type: string SpanEndPayload: additionalProperties: false properties: @@ -3397,7 +3384,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" + \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4761,24 +4748,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inference +- name: MemoryBanks +- name: BatchInference - name: Agents -- name: Telemetry +- name: Inference +- name: DatasetIO - name: Eval - name: Models -- name: Inspect -- name: EvalTasks -- name: ScoringFunctions -- name: Memory -- name: Safety -- name: DatasetIO -- name: MemoryBanks -- name: Shields - name: PostTraining +- name: ScoringFunctions - name: Datasets -- name: Scoring +- name: Shields +- name: Telemetry +- name: Inspect +- name: Safety - name: SyntheticDataGeneration -- name: BatchInference +- name: Memory +- name: Scoring +- name: EvalTasks - description: name: BuiltinTool - description: ' name: Shield -- description: - name: ShieldType - description: name: Trace - description: 'Checkpoint created during training runs @@ -5343,7 +5328,6 @@ x-tagGroups: - Session - Shield - ShieldCallStep - - ShieldType - SpanEndPayload - SpanStartPayload - SpanStatus diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index e1e9301d3..f5352627e 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -182,13 +182,13 @@ " pass\n", "\n", " async def run_shield(\n", - " self, shield_type: str, messages: List[dict]\n", + " self, shield_id: str, messages: List[dict]\n", " ) -> RunShieldResponse:\n", " async with httpx.AsyncClient() as client:\n", " response = await client.post(\n", " f\"{self.base_url}/safety/run_shield\",\n", " json=dict(\n", - " shield_type=shield_type,\n", + " shield_id=shield_id,\n", " messages=[encodable_dict(m) for m in messages],\n", " ),\n", " headers={\n", @@ -216,7 +216,7 @@ " ]:\n", " cprint(f\"User>{message['content']}\", \"green\")\n", " response = await client.run_shield(\n", - " shield_type=\"llama_guard\",\n", + " shield_id=\"Llama-Guard-3-1B\",\n", " messages=[message],\n", " )\n", " print(response)\n", diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index f0f02b3c5..2dc74e6ec 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource): return self.provider_resource_id -@json_schema_type class DatasetInput(CommonDatasetFields, BaseModel): dataset_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py index 10c35c3ee..940dafc06 100644 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource): return self.provider_resource_id -@json_schema_type class EvalTaskInput(CommonEvalTaskFields, BaseModel): eval_task_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 83b292612..c0a0c643a 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -122,7 +122,6 @@ MemoryBank = Annotated[ ] -@json_schema_type class MemoryBankInput(BaseModel): memory_bank_id: str params: BankParams diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index a5d226886..2cd12b4bc 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource): return self.provider_resource_id -@json_schema_type class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 7a2a83c72..251a683c1 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource): return self.provider_resource_id -@json_schema_type class ScoringFnInput(CommonScoringFnFields, BaseModel): scoring_fn_id: str provider_id: Optional[str] = None diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py index 2f6b5e649..7556d2d12 100644 --- a/llama_stack/apis/shields/client.py +++ b/llama_stack/apis/shields/client.py @@ -37,7 +37,6 @@ class ShieldsClient(Shields): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str], provider_id: Optional[str], params: Optional[Dict[str, Any]], @@ -47,7 +46,6 @@ class ShieldsClient(Shields): f"{self.base_url}/shields/register", json={ "shield_id": shield_id, - "shield_type": shield_type, "provider_shield_id": provider_shield_id, "provider_id": provider_id, "params": params, @@ -56,12 +54,12 @@ class ShieldsClient(Shields): ) response.raise_for_status() - async def get_shield(self, shield_type: str) -> Optional[Shield]: + async def get_shield(self, shield_id: str) -> Optional[Shield]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/shields/get", params={ - "shield_type": shield_type, + "shield_id": shield_id, }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 1dcfd4f4c..5ee444f68 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -13,16 +12,7 @@ from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class ShieldType(Enum): - generic_content_shield = "generic_content_shield" - llama_guard = "llama_guard" - code_scanner = "code_scanner" - prompt_guard = "prompt_guard" - - class CommonShieldFields(BaseModel): - shield_type: ShieldType params: Optional[Dict[str, Any]] = None @@ -59,7 +49,6 @@ class Shields(Protocol): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5f6395e0d..220dfdb56 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -172,13 +172,12 @@ class SafetyRouter(Safety): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: return await self.routing_table.register_shield( - shield_id, shield_type, provider_shield_id, provider_id, params + shield_id, provider_shield_id, provider_id, params ) async def run_shield( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 7b369df2c..0ca798641 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -241,7 +241,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def register_shield( self, shield_id: str, - shield_type: ShieldType, provider_shield_id: Optional[str] = None, provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, @@ -260,7 +259,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): params = {} shield = Shield( identifier=shield_id, - shield_type=shield_type, provider_resource_id=provider_shield_id, provider_id=provider_id, params=params, diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 1ca65c9bb..c477c685c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -14,6 +14,12 @@ from .config import CodeScannerConfig from llama_stack.apis.safety import * # noqa: F403 +ALLOWED_CODE_SCANNER_MODEL_IDS = [ + "CodeScanner", + "CodeShield", +] + + class MetaReferenceCodeScannerSafetyImpl(Safety): def __init__(self, config: CodeScannerConfig, deps) -> None: self.config = config @@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.code_scanner: - raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") + if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS: + raise ValueError( + f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}" + ) async def run_shield( self, diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py index aec856bce..72036fd1c 100644 --- a/llama_stack/providers/inline/safety/llama_guard/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -6,32 +6,8 @@ from typing import List -from llama_models.sku_list import CoreModelId, safety_models - -from pydantic import BaseModel, field_validator +from pydantic import BaseModel class LlamaGuardConfig(BaseModel): - model: str = "Llama-Guard-3-1B" excluded_categories: List[str] = [] - - @field_validator("model") - @classmethod - def validate_model(cls, model: str) -> str: - permitted_models = [ - m.descriptor() - for m in safety_models() - if ( - m.core_model_id - in { - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_3_1b, - CoreModelId.llama_guard_3_11b_vision, - } - ) - ] - if model not in permitted_models: - raise ValueError( - f"Invalid model: {model}. Must be one of {permitted_models}" - ) - return model diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 12d012b16..494c1b43e 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_ELECTIONS, ] +LLAMA_GUARD_MODEL_IDS = [ + CoreModelId.llama_guard_3_8b.value, + CoreModelId.llama_guard_3_1b.value, + CoreModelId.llama_guard_3_11b_vision.value, +] MODEL_TO_SAFETY_CATEGORIES_MAP = { CoreModelId.llama_guard_3_8b.value: ( @@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): self.inference_api = deps[Api.inference] async def initialize(self) -> None: - self.shield = LlamaGuardShield( - model=self.config.model, - inference_api=self.inference_api, - excluded_categories=self.config.excluded_categories, - ) + pass async def shutdown(self) -> None: pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.llama_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS: + raise ValueError( + f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}" + ) async def run_shield( self, @@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if len(messages) > 0 and messages[0].role != Role.user.value: messages[0] = UserMessage(content=messages[0].content) - return await self.shield.run(messages) + impl = LlamaGuardShield( + model=shield.provider_resource_id, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + return await impl.run(messages) class LlamaGuardShield: diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 20bfdd241..9f3d78374 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.prompt_guard: - raise ValueError(f"Unsupported shield type: {shield.shield_type}") + if shield.provider_resource_id != PROMPT_GUARD_MODEL: + raise ValueError( + f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. " + ) async def run_shield( self, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index d49035321..78e8105e0 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig logger = logging.getLogger(__name__) -BEDROCK_SUPPORTED_SHIELDS = [ - ShieldType.generic_content_shield, -] - - class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index b3f3dc31c..47e5a751f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config): class TestAgents: @pytest.mark.asyncio - async def test_agent_turns_with_safety(self, agents_stack, common_params): + async def test_agent_turns_with_safety( + self, safety_model, agents_stack, common_params + ): agents_impl, _ = agents_stack agent_id, session_id = await create_agent_session( agents_impl, AgentConfig( **{ **common_params, - "input_shields": ["llama_guard"], - "output_shields": ["llama_guard"], + "input_shields": [safety_model], + "output_shields": [safety_model], } ), ) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 66576e9d7..437c8fa54 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput -from llama_stack.apis.shields import ShieldInput, ShieldType +from llama_stack.apis.shields import ShieldInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture: Provider( provider_id="inline::llama-guard", provider_type="inline::llama-guard", - config=LlamaGuardConfig(model=safety_model).model_dump(), + config=LlamaGuardConfig().model_dump(), ) ], ) @@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request): def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: - shield_config = {} - shield_type = ShieldType.llama_guard - identifier = "llama_guard" - if provider_type == "meta-reference": - shield_config["model"] = safety_model - elif provider_type == "remote::together": - shield_config["model"] = safety_model - elif provider_type == "remote::bedrock": + if provider_type == "remote::bedrock": identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") - shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") - shield_type = ShieldType.generic_content_shield + params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} + else: + params = {} + identifier = safety_model return ShieldInput( shield_id=identifier, - shield_type=shield_type, - params=shield_config, + params=params, ) diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 48fab9741..9daa7bf40 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -34,7 +34,6 @@ class TestSafety: for shield in response: assert isinstance(shield, Shield) - assert shield.shield_type in [v for v in ShieldType] @pytest.mark.asyncio async def test_run_shield(self, safety_stack):