Kill the notion of shield_type

This commit is contained in:
Ashwin Bharambe 2024-11-12 11:41:23 -08:00
parent 09269e2a44
commit b1c3a95485
20 changed files with 87 additions and 161 deletions

View file

@ -21,7 +21,7 @@
"info": { "info": {
"title": "[DRAFT] Llama Stack Specification", "title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1", "version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
}, },
"servers": [ "servers": [
{ {
@ -5743,9 +5743,6 @@
"const": "shield", "const": "shield",
"default": "shield" "default": "shield"
}, },
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"params": { "params": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
@ -5777,20 +5774,10 @@
"identifier", "identifier",
"provider_resource_id", "provider_resource_id",
"provider_id", "provider_id",
"type", "type"
"shield_type"
], ],
"title": "A safety shield resource that can be used to check content" "title": "A safety shield resource that can be used to check content"
}, },
"ShieldType": {
"type": "string",
"enum": [
"generic_content_shield",
"llama_guard",
"code_scanner",
"prompt_guard"
]
},
"Trace": { "Trace": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7262,9 +7249,6 @@
"shield_id": { "shield_id": {
"type": "string" "type": "string"
}, },
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"provider_shield_id": { "provider_shield_id": {
"type": "string" "type": "string"
}, },
@ -7299,8 +7283,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"shield_id", "shield_id"
"shield_type"
] ]
}, },
"RunEvalRequest": { "RunEvalRequest": {
@ -7854,13 +7837,19 @@
], ],
"tags": [ "tags": [
{ {
"name": "Inference" "name": "MemoryBanks"
},
{
"name": "BatchInference"
}, },
{ {
"name": "Agents" "name": "Agents"
}, },
{ {
"name": "Telemetry" "name": "Inference"
},
{
"name": "DatasetIO"
}, },
{ {
"name": "Eval" "name": "Eval"
@ -7869,43 +7858,37 @@
"name": "Models" "name": "Models"
}, },
{ {
"name": "Inspect" "name": "PostTraining"
},
{
"name": "EvalTasks"
}, },
{ {
"name": "ScoringFunctions" "name": "ScoringFunctions"
}, },
{ {
"name": "Memory" "name": "Datasets"
},
{
"name": "Safety"
},
{
"name": "DatasetIO"
},
{
"name": "MemoryBanks"
}, },
{ {
"name": "Shields" "name": "Shields"
}, },
{ {
"name": "PostTraining" "name": "Telemetry"
}, },
{ {
"name": "Datasets" "name": "Inspect"
}, },
{ {
"name": "Scoring" "name": "Safety"
}, },
{ {
"name": "SyntheticDataGeneration" "name": "SyntheticDataGeneration"
}, },
{ {
"name": "BatchInference" "name": "Memory"
},
{
"name": "Scoring"
},
{
"name": "EvalTasks"
}, },
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
@ -8255,10 +8238,6 @@
"name": "Shield", "name": "Shield",
"description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />" "description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />"
}, },
{
"name": "ShieldType",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldType\" />"
},
{ {
"name": "Trace", "name": "Trace",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
@ -8614,7 +8593,6 @@
"Session", "Session",
"Shield", "Shield",
"ShieldCallStep", "ShieldCallStep",
"ShieldType",
"SpanEndPayload", "SpanEndPayload",
"SpanStartPayload", "SpanStartPayload",
"SpanStatus", "SpanStatus",

View file

@ -2227,11 +2227,8 @@ components:
type: string type: string
shield_id: shield_id:
type: string type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
required: required:
- shield_id - shield_id
- shield_type
type: object type: object
RestAPIExecutionConfig: RestAPIExecutionConfig:
additionalProperties: false additionalProperties: false
@ -2698,8 +2695,6 @@ components:
type: string type: string
provider_resource_id: provider_resource_id:
type: string type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
type: type:
const: shield const: shield
default: shield default: shield
@ -2709,7 +2704,6 @@ components:
- provider_resource_id - provider_resource_id
- provider_id - provider_id
- type - type
- shield_type
title: A safety shield resource that can be used to check content title: A safety shield resource that can be used to check content
type: object type: object
ShieldCallStep: ShieldCallStep:
@ -2736,13 +2730,6 @@ components:
- step_id - step_id
- step_type - step_type
type: object type: object
ShieldType:
enum:
- generic_content_shield
- llama_guard
- code_scanner
- prompt_guard
type: string
SpanEndPayload: SpanEndPayload:
additionalProperties: false additionalProperties: false
properties: properties:
@ -3397,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -4761,24 +4748,24 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: Inference - name: MemoryBanks
- name: BatchInference
- name: Agents - name: Agents
- name: Telemetry - name: Inference
- name: DatasetIO
- name: Eval - name: Eval
- name: Models - name: Models
- name: Inspect
- name: EvalTasks
- name: ScoringFunctions
- name: Memory
- name: Safety
- name: DatasetIO
- name: MemoryBanks
- name: Shields
- name: PostTraining - name: PostTraining
- name: ScoringFunctions
- name: Datasets - name: Datasets
- name: Scoring - name: Shields
- name: Telemetry
- name: Inspect
- name: Safety
- name: SyntheticDataGeneration - name: SyntheticDataGeneration
- name: BatchInference - name: Memory
- name: Scoring
- name: EvalTasks
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -5046,8 +5033,6 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/Shield" />' <SchemaDefinition schemaRef="#/components/schemas/Shield" />'
name: Shield name: Shield
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldType" />
name: ShieldType
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" /> - description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace name: Trace
- description: 'Checkpoint created during training runs - description: 'Checkpoint created during training runs
@ -5343,7 +5328,6 @@ x-tagGroups:
- Session - Session
- Shield - Shield
- ShieldCallStep - ShieldCallStep
- ShieldType
- SpanEndPayload - SpanEndPayload
- SpanStartPayload - SpanStartPayload
- SpanStatus - SpanStatus

View file

@ -182,13 +182,13 @@
" pass\n", " pass\n",
"\n", "\n",
" async def run_shield(\n", " async def run_shield(\n",
" self, shield_type: str, messages: List[dict]\n", " self, shield_id: str, messages: List[dict]\n",
" ) -> RunShieldResponse:\n", " ) -> RunShieldResponse:\n",
" async with httpx.AsyncClient() as client:\n", " async with httpx.AsyncClient() as client:\n",
" response = await client.post(\n", " response = await client.post(\n",
" f\"{self.base_url}/safety/run_shield\",\n", " f\"{self.base_url}/safety/run_shield\",\n",
" json=dict(\n", " json=dict(\n",
" shield_type=shield_type,\n", " shield_id=shield_id,\n",
" messages=[encodable_dict(m) for m in messages],\n", " messages=[encodable_dict(m) for m in messages],\n",
" ),\n", " ),\n",
" headers={\n", " headers={\n",
@ -216,7 +216,7 @@
" ]:\n", " ]:\n",
" cprint(f\"User>{message['content']}\", \"green\")\n", " cprint(f\"User>{message['content']}\", \"green\")\n",
" response = await client.run_shield(\n", " response = await client.run_shield(\n",
" shield_type=\"llama_guard\",\n", " shield_id=\"Llama-Guard-3-1B\",\n",
" messages=[message],\n", " messages=[message],\n",
" )\n", " )\n",
" print(response)\n", " print(response)\n",

View file

@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel): class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str dataset_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel): class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str eval_task_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -122,7 +122,6 @@ MemoryBank = Annotated[
] ]
@json_schema_type
class MemoryBankInput(BaseModel): class MemoryBankInput(BaseModel):
memory_bank_id: str memory_bank_id: str
params: BankParams params: BankParams

View file

@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel): class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str scoring_fn_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -37,7 +37,6 @@ class ShieldsClient(Shields):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str], provider_shield_id: Optional[str],
provider_id: Optional[str], provider_id: Optional[str],
params: Optional[Dict[str, Any]], params: Optional[Dict[str, Any]],
@ -47,7 +46,6 @@ class ShieldsClient(Shields):
f"{self.base_url}/shields/register", f"{self.base_url}/shields/register",
json={ json={
"shield_id": shield_id, "shield_id": shield_id,
"shield_type": shield_type,
"provider_shield_id": provider_shield_id, "provider_shield_id": provider_shield_id,
"provider_id": provider_id, "provider_id": provider_id,
"params": params, "params": params,
@ -56,12 +54,12 @@ class ShieldsClient(Shields):
) )
response.raise_for_status() response.raise_for_status()
async def get_shield(self, shield_type: str) -> Optional[Shield]: async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/shields/get", f"{self.base_url}/shields/get",
params={ params={
"shield_type": shield_type, "shield_id": shield_id,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -13,16 +12,7 @@ from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class ShieldType(Enum):
generic_content_shield = "generic_content_shield"
llama_guard = "llama_guard"
code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class CommonShieldFields(BaseModel): class CommonShieldFields(BaseModel):
shield_type: ShieldType
params: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None
@ -59,7 +49,6 @@ class Shields(Protocol):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,

View file

@ -172,13 +172,12 @@ class SafetyRouter(Safety):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
return await self.routing_table.register_shield( return await self.routing_table.register_shield(
shield_id, shield_type, provider_shield_id, provider_id, params shield_id, provider_shield_id, provider_id, params
) )
async def run_shield( async def run_shield(

View file

@ -241,7 +241,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
@ -260,7 +259,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
params = {} params = {}
shield = Shield( shield = Shield(
identifier=shield_id, identifier=shield_id,
shield_type=shield_type,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,
params=params, params=params,

View file

@ -14,6 +14,12 @@ from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
]
class MetaReferenceCodeScannerSafetyImpl(Safety): class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None: def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config self.config = config
@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.code_scanner: if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") raise ValueError(
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
)
async def run_shield( async def run_shield(
self, self,

View file

@ -6,32 +6,8 @@
from typing import List from typing import List
from llama_models.sku_list import CoreModelId, safety_models from pydantic import BaseModel
from pydantic import BaseModel, field_validator
class LlamaGuardConfig(BaseModel): class LlamaGuardConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = [] excluded_categories: List[str] = []
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if (
m.core_model_id
in {
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
}
)
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model

View file

@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS, CAT_ELECTIONS,
] ]
LLAMA_GUARD_MODEL_IDS = [
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
MODEL_TO_SAFETY_CATEGORIES_MAP = { MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: ( CoreModelId.llama_guard_3_8b.value: (
@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self.inference_api = deps[Api.inference] self.inference_api = deps[Api.inference]
async def initialize(self) -> None: async def initialize(self) -> None:
self.shield = LlamaGuardShield( pass
model=self.config.model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.llama_guard: if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(f"Unsupported shield type: {shield.shield_type}") raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
async def run_shield( async def run_shield(
self, self,
@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value: if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content) messages[0] = UserMessage(content=messages[0].content)
return await self.shield.run(messages) impl = LlamaGuardShield(
model=shield.provider_resource_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
return await impl.run(messages)
class LlamaGuardShield: class LlamaGuardShield:

View file

@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.prompt_guard: if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Unsupported shield type: {shield.shield_type}") raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
async def run_shield( async def run_shield(
self, self,

View file

@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield,
]
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None: def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config self.config = config

View file

@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
class TestAgents: class TestAgents:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_turns_with_safety(self, agents_stack, common_params): async def test_agent_turns_with_safety(
self, safety_model, agents_stack, common_params
):
agents_impl, _ = agents_stack agents_impl, _ = agents_stack
agent_id, session_id = await create_agent_session( agent_id, session_id = await create_agent_session(
agents_impl, agents_impl,
AgentConfig( AgentConfig(
**{ **{
**common_params, **common_params,
"input_shields": ["llama_guard"], "input_shields": [safety_model],
"output_shields": ["llama_guard"], "output_shields": [safety_model],
} }
), ),
) )

View file

@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput, ShieldType from llama_stack.apis.shields import ShieldInput
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
Provider( Provider(
provider_id="inline::llama-guard", provider_id="inline::llama-guard",
provider_type="inline::llama-guard", provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(), config=LlamaGuardConfig().model_dump(),
) )
], ],
) )
@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request):
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
shield_config = {} if provider_type == "remote::bedrock":
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
if provider_type == "meta-reference":
shield_config["model"] = safety_model
elif provider_type == "remote::together":
shield_config["model"] = safety_model
elif provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
shield_type = ShieldType.generic_content_shield else:
params = {}
identifier = safety_model
return ShieldInput( return ShieldInput(
shield_id=identifier, shield_id=identifier,
shield_type=shield_type, params=params,
params=shield_config,
) )

View file

@ -34,7 +34,6 @@ class TestSafety:
for shield in response: for shield in response:
assert isinstance(shield, Shield) assert isinstance(shield, Shield)
assert shield.shield_type in [v for v in ShieldType]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_shield(self, safety_stack): async def test_run_shield(self, safety_stack):