Remove the "ShieldType" concept (#430)

# What does this PR do?

This PR kills the notion of "ShieldType". The impetus for this is the
realization:

> Why is keyword llama-guard appearing so many times everywhere,
sometimes with hyphens, sometimes with underscores?

Now that we have a notion of "provider specific resource identifiers"
and "user specific aliases" for those and the fact that this works with
models ("Llama3.1-8B-Instruct" <> "fireworks/llama-3pv1-..."), we can
follow the same rules for Shields.

So each Safety provider can make up a notion of identifiers it has
registered. This already happens with Bedrock correctly. We just
generalize it for Llama Guard, Prompt Guard, etc.

For Llama Guard, we further simplify by just adopting the underlying
model name itself as the identifier! No confusion necessary.

While doing this, I noticed a bug in our DistributionRegistry where we
weren't scoping identifiers by type. Fixed.

## Feature/Issue validation/testing/test plan

Ran (inference, safety, memory, agents) tests with ollama and fireworks
providers.
This commit is contained in:
Ashwin Bharambe 2024-11-12 12:37:24 -08:00 committed by GitHub
parent 09269e2a44
commit 983d6ce2df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 150 additions and 209 deletions

View file

@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant
- [ ] Addresses issue (#issue) - [ ] Addresses issue (#issue)
## Feature/Issue validation/testing/test plan
## Test Plan
Please describe: Please describe:
- tests you ran to verify your changes with result summaries. - tests you ran to verify your changes with result summaries.

View file

@ -180,8 +180,8 @@
" tools=tools,\n", " tools=tools,\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n", " tool_prompt_format=\"json\",\n",
" input_shields=[\"llama_guard\"],\n", " input_shields=[\"Llama-Guard-3-1B\"],\n",
" output_shields=[\"llama_guard\"],\n", " output_shields=[\"Llama-Guard-3-1B\"],\n",
" enable_session_persistence=True,\n", " enable_session_persistence=True,\n",
" )\n", " )\n",
"\n", "\n",

View file

@ -21,7 +21,7 @@
"info": { "info": {
"title": "[DRAFT] Llama Stack Specification", "title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1", "version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
}, },
"servers": [ "servers": [
{ {
@ -5743,9 +5743,6 @@
"const": "shield", "const": "shield",
"default": "shield" "default": "shield"
}, },
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"params": { "params": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
@ -5777,20 +5774,10 @@
"identifier", "identifier",
"provider_resource_id", "provider_resource_id",
"provider_id", "provider_id",
"type", "type"
"shield_type"
], ],
"title": "A safety shield resource that can be used to check content" "title": "A safety shield resource that can be used to check content"
}, },
"ShieldType": {
"type": "string",
"enum": [
"generic_content_shield",
"llama_guard",
"code_scanner",
"prompt_guard"
]
},
"Trace": { "Trace": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7262,9 +7249,6 @@
"shield_id": { "shield_id": {
"type": "string" "type": "string"
}, },
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"provider_shield_id": { "provider_shield_id": {
"type": "string" "type": "string"
}, },
@ -7299,8 +7283,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"shield_id", "shield_id"
"shield_type"
] ]
}, },
"RunEvalRequest": { "RunEvalRequest": {
@ -7854,13 +7837,19 @@
], ],
"tags": [ "tags": [
{ {
"name": "Inference" "name": "MemoryBanks"
},
{
"name": "BatchInference"
}, },
{ {
"name": "Agents" "name": "Agents"
}, },
{ {
"name": "Telemetry" "name": "Inference"
},
{
"name": "DatasetIO"
}, },
{ {
"name": "Eval" "name": "Eval"
@ -7869,43 +7858,37 @@
"name": "Models" "name": "Models"
}, },
{ {
"name": "Inspect" "name": "PostTraining"
},
{
"name": "EvalTasks"
}, },
{ {
"name": "ScoringFunctions" "name": "ScoringFunctions"
}, },
{ {
"name": "Memory" "name": "Datasets"
},
{
"name": "Safety"
},
{
"name": "DatasetIO"
},
{
"name": "MemoryBanks"
}, },
{ {
"name": "Shields" "name": "Shields"
}, },
{ {
"name": "PostTraining" "name": "Telemetry"
}, },
{ {
"name": "Datasets" "name": "Inspect"
}, },
{ {
"name": "Scoring" "name": "Safety"
}, },
{ {
"name": "SyntheticDataGeneration" "name": "SyntheticDataGeneration"
}, },
{ {
"name": "BatchInference" "name": "Memory"
},
{
"name": "Scoring"
},
{
"name": "EvalTasks"
}, },
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
@ -8255,10 +8238,6 @@
"name": "Shield", "name": "Shield",
"description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />" "description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />"
}, },
{
"name": "ShieldType",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldType\" />"
},
{ {
"name": "Trace", "name": "Trace",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
@ -8614,7 +8593,6 @@
"Session", "Session",
"Shield", "Shield",
"ShieldCallStep", "ShieldCallStep",
"ShieldType",
"SpanEndPayload", "SpanEndPayload",
"SpanStartPayload", "SpanStartPayload",
"SpanStatus", "SpanStatus",

View file

@ -2227,11 +2227,8 @@ components:
type: string type: string
shield_id: shield_id:
type: string type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
required: required:
- shield_id - shield_id
- shield_type
type: object type: object
RestAPIExecutionConfig: RestAPIExecutionConfig:
additionalProperties: false additionalProperties: false
@ -2698,8 +2695,6 @@ components:
type: string type: string
provider_resource_id: provider_resource_id:
type: string type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
type: type:
const: shield const: shield
default: shield default: shield
@ -2709,7 +2704,6 @@ components:
- provider_resource_id - provider_resource_id
- provider_id - provider_id
- type - type
- shield_type
title: A safety shield resource that can be used to check content title: A safety shield resource that can be used to check content
type: object type: object
ShieldCallStep: ShieldCallStep:
@ -2736,13 +2730,6 @@ components:
- step_id - step_id
- step_type - step_type
type: object type: object
ShieldType:
enum:
- generic_content_shield
- llama_guard
- code_scanner
- prompt_guard
type: string
SpanEndPayload: SpanEndPayload:
additionalProperties: false additionalProperties: false
properties: properties:
@ -3397,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871" \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -4761,24 +4748,24 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: Inference - name: MemoryBanks
- name: BatchInference
- name: Agents - name: Agents
- name: Telemetry - name: Inference
- name: DatasetIO
- name: Eval - name: Eval
- name: Models - name: Models
- name: Inspect
- name: EvalTasks
- name: ScoringFunctions
- name: Memory
- name: Safety
- name: DatasetIO
- name: MemoryBanks
- name: Shields
- name: PostTraining - name: PostTraining
- name: ScoringFunctions
- name: Datasets - name: Datasets
- name: Scoring - name: Shields
- name: Telemetry
- name: Inspect
- name: Safety
- name: SyntheticDataGeneration - name: SyntheticDataGeneration
- name: BatchInference - name: Memory
- name: Scoring
- name: EvalTasks
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -5046,8 +5033,6 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/Shield" />' <SchemaDefinition schemaRef="#/components/schemas/Shield" />'
name: Shield name: Shield
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldType" />
name: ShieldType
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" /> - description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace name: Trace
- description: 'Checkpoint created during training runs - description: 'Checkpoint created during training runs
@ -5343,7 +5328,6 @@ x-tagGroups:
- Session - Session
- Shield - Shield
- ShieldCallStep - ShieldCallStep
- ShieldType
- SpanEndPayload - SpanEndPayload
- SpanStartPayload - SpanStartPayload
- SpanStatus - SpanStatus

View file

@ -182,13 +182,13 @@
" pass\n", " pass\n",
"\n", "\n",
" async def run_shield(\n", " async def run_shield(\n",
" self, shield_type: str, messages: List[dict]\n", " self, shield_id: str, messages: List[dict]\n",
" ) -> RunShieldResponse:\n", " ) -> RunShieldResponse:\n",
" async with httpx.AsyncClient() as client:\n", " async with httpx.AsyncClient() as client:\n",
" response = await client.post(\n", " response = await client.post(\n",
" f\"{self.base_url}/safety/run_shield\",\n", " f\"{self.base_url}/safety/run_shield\",\n",
" json=dict(\n", " json=dict(\n",
" shield_type=shield_type,\n", " shield_id=shield_id,\n",
" messages=[encodable_dict(m) for m in messages],\n", " messages=[encodable_dict(m) for m in messages],\n",
" ),\n", " ),\n",
" headers={\n", " headers={\n",
@ -216,7 +216,7 @@
" ]:\n", " ]:\n",
" cprint(f\"User>{message['content']}\", \"green\")\n", " cprint(f\"User>{message['content']}\", \"green\")\n",
" response = await client.run_shield(\n", " response = await client.run_shield(\n",
" shield_type=\"llama_guard\",\n", " shield_id=\"Llama-Guard-3-1B\",\n",
" messages=[message],\n", " messages=[message],\n",
" )\n", " )\n",
" print(response)\n", " print(response)\n",

View file

@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel): class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str dataset_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel): class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str eval_task_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -122,7 +122,6 @@ MemoryBank = Annotated[
] ]
@json_schema_type
class MemoryBankInput(BaseModel): class MemoryBankInput(BaseModel):
memory_bank_id: str memory_bank_id: str
params: BankParams params: BankParams

View file

@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
def encodable_dict(d: BaseModel): def encodable_dict(d: BaseModel):
return json.loads(d.json()) return json.loads(d.model_dump_json())
class SafetyClient(Safety): class SafetyClient(Safety):
@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
) )
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
response = await client.run_shield( response = await client.run_shield(
shield_id="llama_guard", shield_id="Llama-Guard-3-1B",
messages=[message], messages=[message],
) )
print(response) print(response)

View file

@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel): class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str scoring_fn_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None

View file

@ -37,7 +37,6 @@ class ShieldsClient(Shields):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str], provider_shield_id: Optional[str],
provider_id: Optional[str], provider_id: Optional[str],
params: Optional[Dict[str, Any]], params: Optional[Dict[str, Any]],
@ -47,7 +46,6 @@ class ShieldsClient(Shields):
f"{self.base_url}/shields/register", f"{self.base_url}/shields/register",
json={ json={
"shield_id": shield_id, "shield_id": shield_id,
"shield_type": shield_type,
"provider_shield_id": provider_shield_id, "provider_shield_id": provider_shield_id,
"provider_id": provider_id, "provider_id": provider_id,
"params": params, "params": params,
@ -56,12 +54,12 @@ class ShieldsClient(Shields):
) )
response.raise_for_status() response.raise_for_status()
async def get_shield(self, shield_type: str) -> Optional[Shield]: async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/shields/get", f"{self.base_url}/shields/get",
params={ params={
"shield_type": shield_type, "shield_id": shield_id,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -13,16 +12,7 @@ from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class ShieldType(Enum):
generic_content_shield = "generic_content_shield"
llama_guard = "llama_guard"
code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class CommonShieldFields(BaseModel): class CommonShieldFields(BaseModel):
shield_type: ShieldType
params: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None
@ -59,7 +49,6 @@ class Shields(Protocol):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,

View file

@ -172,13 +172,12 @@ class SafetyRouter(Safety):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
return await self.routing_table.register_shield( return await self.routing_table.register_shield(
shield_id, shield_type, provider_shield_id, provider_id, params shield_id, provider_shield_id, provider_id, params
) )
async def run_shield( async def run_shield(

View file

@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable):
else: else:
raise ValueError("Unknown routing table type") raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry # Get objects from disk registry
objects = self.dist_registry.get_cached(routing_key) objects = self.dist_registry.get_cached(objtype, routing_key)
if not objects: if not objects:
apiname, objname = apiname_object()
provider_ids = list(self.impls_by_provider_id.keys()) provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1: if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else: else:
provider_ids_str = f"provider: `{provider_ids[0]}`" provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError( raise ValueError(
f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}." f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
) )
for obj in objects: for obj in objects:
@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`") raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier( async def get_object_by_identifier(
self, identifier: str self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]: ) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry # Get from disk registry
objects = await self.dist_registry.get(identifier) objects = await self.dist_registry.get(type, identifier)
if not objects: if not objects:
return None return None
# kind of ill-defined behavior here, but we'll just return the first one assert len(objects) == 1
return objects[0] return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider): async def register_object(self, obj: RoutableObjectWithProvider):
# Get existing objects from registry # Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.identifier) existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
# Check for existing registration # Check for existing registration
for existing_obj in existing_objects: for existing_obj in existing_objects:
@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return await self.get_all_with_type("model") return await self.get_all_with_type("model")
async def get_model(self, identifier: str) -> Optional[Model]: async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier(identifier) return await self.get_object_by_identifier("model", identifier)
async def register_model( async def register_model(
self, self,
@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return await self.get_all_with_type(ResourceType.shield.value) return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]: async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier(identifier) return await self.get_object_by_identifier("shield", identifier)
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None, provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
params = {} params = {}
shield = Shield( shield = Shield(
identifier=shield_id, identifier=shield_id,
shield_type=shield_type,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,
params=params, params=params,
@ -274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return await self.get_all_with_type(ResourceType.memory_bank.value) return await self.get_all_with_type(ResourceType.memory_bank.value)
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier(memory_bank_id) return await self.get_object_by_identifier("memory_bank", memory_bank_id)
async def register_memory_bank( async def register_memory_bank(
self, self,
@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
return await self.get_all_with_type("dataset") return await self.get_all_with_type("dataset")
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier(dataset_id) return await self.get_object_by_identifier("dataset", dataset_id)
async def register_dataset( async def register_dataset(
self, self,
@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value) return await self.get_all_with_type("scoring_function")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier(scoring_fn_id) return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def register_scoring_function( async def register_scoring_function(
self, self,
@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
return await self.get_all_with_type("eval_task") return await self.get_all_with_type("eval_task")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier(name) return await self.get_object_by_identifier("eval_task", name)
async def register_eval_task( async def register_eval_task(
self, self,

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any, Dict
from termcolor import colored
from termcolor import colored from termcolor import colored

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Dict, List, Optional, Protocol from typing import Dict, List, Optional, Protocol, Tuple
import pydantic import pydantic
@ -35,7 +35,8 @@ class DistributionRegistry(Protocol):
async def register(self, obj: RoutableObjectWithProvider) -> bool: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
KEY_FORMAT = "distributions:registry:v1::{}" KEY_VERSION = "v1"
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
class DiskDistributionRegistry(DistributionRegistry): class DiskDistributionRegistry(DistributionRegistry):
@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: def get_cached(
self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
# Disk registry does not have a cache # Disk registry does not have a cache
return [] return []
async def get_all(self) -> List[RoutableObjectWithProvider]: async def get_all(self) -> List[RoutableObjectWithProvider]:
start_key = KEY_FORMAT.format("") start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format("\xff") end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key) keys = await self.kvstore.range(start_key, end_key)
return [await self.get(key.split(":")[-1]) for key in keys]
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) return [await self.get(type, identifier) for type, identifier in tuples]
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
if not json_str: if not json_str:
return [] return []
@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry):
] ]
async def register(self, obj: RoutableObjectWithProvider) -> bool: async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_objects = await self.get(obj.identifier) existing_objects = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists # dont register if the object's providerid already exists
for eobj in existing_objects: for eobj in existing_objects:
if eobj.provider_id == obj.provider_id: if eobj.provider_id == obj.provider_id:
@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry):
obj.model_dump_json() for obj in existing_objects obj.model_dump_json() for obj in existing_objects
] # Fixed variable name ] # Fixed variable name
await self.kvstore.set( await self.kvstore.set(
KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
json.dumps(objects_json),
) )
return True return True
@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore): def __init__(self, kvstore: KVStore):
super().__init__(kvstore) super().__init__(kvstore)
self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
async def initialize(self) -> None: async def initialize(self) -> None:
start_key = KEY_FORMAT.format("") start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format("\xff") end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key) keys = await self.kvstore.range(start_key, end_key)
for key in keys: for key in keys:
identifier = key.split(":")[-1] type, identifier = key.split(":")[-2:]
objects = await super().get(identifier) objects = await super().get(type, identifier)
if objects: if objects:
self.cache[identifier] = objects self.cache[type, identifier] = objects
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: def get_cached(
return self.cache.get(identifier, []) self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), [])
async def get_all(self) -> List[RoutableObjectWithProvider]: async def get_all(self) -> List[RoutableObjectWithProvider]:
return [item for sublist in self.cache.values() for item in sublist] return [item for sublist in self.cache.values() for item in sublist]
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
if identifier in self.cache: cachekey = (type, identifier)
return self.cache[identifier] if cachekey in self.cache:
return self.cache[cachekey]
objects = await super().get(identifier) objects = await super().get(type, identifier)
if objects: if objects:
self.cache[identifier] = objects self.cache[cachekey] = objects
return objects return objects
@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
if success: if success:
# Then update cache # Then update cache
if obj.identifier not in self.cache: cachekey = (obj.type, obj.identifier)
self.cache[obj.identifier] = [] if cachekey not in self.cache:
self.cache[cachekey] = []
# Check if provider already exists in cache # Check if provider already exists in cache
for cached_obj in self.cache[obj.identifier]: for cached_obj in self.cache[cachekey]:
if cached_obj.provider_id == obj.provider_id: if cached_obj.provider_id == obj.provider_id:
return success return success
# If not, update cache # If not, update cache
self.cache[obj.identifier].append(obj) self.cache[cachekey].append(obj)
return success return success

View file

@ -14,6 +14,12 @@ from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
]
class MetaReferenceCodeScannerSafetyImpl(Safety): class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None: def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config self.config = config
@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.code_scanner: if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") raise ValueError(
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
)
async def run_shield( async def run_shield(
self, self,

View file

@ -6,32 +6,8 @@
from typing import List from typing import List
from llama_models.sku_list import CoreModelId, safety_models from pydantic import BaseModel
from pydantic import BaseModel, field_validator
class LlamaGuardConfig(BaseModel): class LlamaGuardConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = [] excluded_categories: List[str] = []
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if (
m.core_model_id
in {
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
}
)
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model

View file

@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS, CAT_ELECTIONS,
] ]
LLAMA_GUARD_MODEL_IDS = [
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
MODEL_TO_SAFETY_CATEGORIES_MAP = { MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: ( CoreModelId.llama_guard_3_8b.value: (
@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self.inference_api = deps[Api.inference] self.inference_api = deps[Api.inference]
async def initialize(self) -> None: async def initialize(self) -> None:
self.shield = LlamaGuardShield( pass
model=self.config.model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.llama_guard: if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(f"Unsupported shield type: {shield.shield_type}") raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
async def run_shield( async def run_shield(
self, self,
@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value: if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content) messages[0] = UserMessage(content=messages[0].content)
return await self.shield.run(messages) impl = LlamaGuardShield(
model=shield.provider_resource_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
return await impl.run(messages)
class LlamaGuardShield: class LlamaGuardShield:

View file

@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.prompt_guard: if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Unsupported shield type: {shield.shield_type}") raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
async def run_shield( async def run_shield(
self, self,

View file

@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield,
]
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None: def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config self.config = config

View file

@ -44,7 +44,7 @@ def agents_meta_reference() -> ProviderFixture:
providers=[ providers=[
Provider( Provider(
provider_id="meta-reference", provider_id="meta-reference",
provider_type="meta-reference", provider_type="inline::meta-reference",
config=MetaReferenceAgentsImplConfig( config=MetaReferenceAgentsImplConfig(
# TODO: make this an in-memory store # TODO: make this an in-memory store
persistence_store=SqliteKVStoreConfig( persistence_store=SqliteKVStoreConfig(

View file

@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
class TestAgents: class TestAgents:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_turns_with_safety(self, agents_stack, common_params): async def test_agent_turns_with_safety(
self, safety_model, agents_stack, common_params
):
agents_impl, _ = agents_stack agents_impl, _ = agents_stack
agent_id, session_id = await create_agent_session( agent_id, session_id = await create_agent_session(
agents_impl, agents_impl,
AgentConfig( AgentConfig(
**{ **{
**common_params, **common_params,
"input_shields": ["llama_guard"], "input_shields": [safety_model],
"output_shields": ["llama_guard"], "output_shields": [safety_model],
} }
), ),
) )

View file

@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput, ShieldType from llama_stack.apis.shields import ShieldInput
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
Provider( Provider(
provider_id="inline::llama-guard", provider_id="inline::llama-guard",
provider_type="inline::llama-guard", provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(), config=LlamaGuardConfig().model_dump(),
) )
], ],
) )
@ -101,6 +101,8 @@ async def safety_stack(inference_model, safety_model, request):
shield_provider_type = safety_fixture.providers[0].provider_type shield_provider_type = safety_fixture.providers[0].provider_type
shield_input = get_shield_to_register(shield_provider_type, safety_model) shield_input = get_shield_to_register(shield_provider_type, safety_model)
print(f"inference_model: {inference_model}")
print(f"shield_input = {shield_input}")
impls = await resolve_impls_for_test_v2( impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference], [Api.safety, Api.shields, Api.inference],
providers, providers,
@ -114,20 +116,14 @@ async def safety_stack(inference_model, safety_model, request):
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
shield_config = {} if provider_type == "remote::bedrock":
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
if provider_type == "meta-reference":
shield_config["model"] = safety_model
elif provider_type == "remote::together":
shield_config["model"] = safety_model
elif provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
shield_type = ShieldType.generic_content_shield else:
params = {}
identifier = safety_model
return ShieldInput( return ShieldInput(
shield_id=identifier, shield_id=identifier,
shield_type=shield_type, params=params,
params=shield_config,
) )

View file

@ -34,7 +34,6 @@ class TestSafety:
for shield in response: for shield in response:
assert isinstance(shield, Shield) assert isinstance(shield, Shield)
assert shield.shield_type in [v for v in ShieldType]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_shield(self, safety_stack): async def test_run_shield(self, safety_stack):