mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00
Kill the notion of shield_type
This commit is contained in:
parent
09269e2a44
commit
b1c3a95485
20 changed files with 87 additions and 161 deletions
|
@ -21,7 +21,7 @@
|
||||||
"info": {
|
"info": {
|
||||||
"title": "[DRAFT] Llama Stack Specification",
|
"title": "[DRAFT] Llama Stack Specification",
|
||||||
"version": "0.0.1",
|
"version": "0.0.1",
|
||||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
|
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
@ -5743,9 +5743,6 @@
|
||||||
"const": "shield",
|
"const": "shield",
|
||||||
"default": "shield"
|
"default": "shield"
|
||||||
},
|
},
|
||||||
"shield_type": {
|
|
||||||
"$ref": "#/components/schemas/ShieldType"
|
|
||||||
},
|
|
||||||
"params": {
|
"params": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": {
|
"additionalProperties": {
|
||||||
|
@ -5777,20 +5774,10 @@
|
||||||
"identifier",
|
"identifier",
|
||||||
"provider_resource_id",
|
"provider_resource_id",
|
||||||
"provider_id",
|
"provider_id",
|
||||||
"type",
|
"type"
|
||||||
"shield_type"
|
|
||||||
],
|
],
|
||||||
"title": "A safety shield resource that can be used to check content"
|
"title": "A safety shield resource that can be used to check content"
|
||||||
},
|
},
|
||||||
"ShieldType": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"generic_content_shield",
|
|
||||||
"llama_guard",
|
|
||||||
"code_scanner",
|
|
||||||
"prompt_guard"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"Trace": {
|
"Trace": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -7262,9 +7249,6 @@
|
||||||
"shield_id": {
|
"shield_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"shield_type": {
|
|
||||||
"$ref": "#/components/schemas/ShieldType"
|
|
||||||
},
|
|
||||||
"provider_shield_id": {
|
"provider_shield_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
@ -7299,8 +7283,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"shield_id",
|
"shield_id"
|
||||||
"shield_type"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"RunEvalRequest": {
|
"RunEvalRequest": {
|
||||||
|
@ -7854,13 +7837,19 @@
|
||||||
],
|
],
|
||||||
"tags": [
|
"tags": [
|
||||||
{
|
{
|
||||||
"name": "Inference"
|
"name": "MemoryBanks"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "BatchInference"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Agents"
|
"name": "Agents"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Telemetry"
|
"name": "Inference"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "DatasetIO"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Eval"
|
"name": "Eval"
|
||||||
|
@ -7869,43 +7858,37 @@
|
||||||
"name": "Models"
|
"name": "Models"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Inspect"
|
"name": "PostTraining"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "EvalTasks"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "ScoringFunctions"
|
"name": "ScoringFunctions"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Memory"
|
"name": "Datasets"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Safety"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DatasetIO"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "MemoryBanks"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Shields"
|
"name": "Shields"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "PostTraining"
|
"name": "Telemetry"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Datasets"
|
"name": "Inspect"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Scoring"
|
"name": "Safety"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "SyntheticDataGeneration"
|
"name": "SyntheticDataGeneration"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BatchInference"
|
"name": "Memory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Scoring"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "EvalTasks"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BuiltinTool",
|
"name": "BuiltinTool",
|
||||||
|
@ -8255,10 +8238,6 @@
|
||||||
"name": "Shield",
|
"name": "Shield",
|
||||||
"description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />"
|
"description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "ShieldType",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldType\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "Trace",
|
"name": "Trace",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
|
||||||
|
@ -8614,7 +8593,6 @@
|
||||||
"Session",
|
"Session",
|
||||||
"Shield",
|
"Shield",
|
||||||
"ShieldCallStep",
|
"ShieldCallStep",
|
||||||
"ShieldType",
|
|
||||||
"SpanEndPayload",
|
"SpanEndPayload",
|
||||||
"SpanStartPayload",
|
"SpanStartPayload",
|
||||||
"SpanStatus",
|
"SpanStatus",
|
||||||
|
|
|
@ -2227,11 +2227,8 @@ components:
|
||||||
type: string
|
type: string
|
||||||
shield_id:
|
shield_id:
|
||||||
type: string
|
type: string
|
||||||
shield_type:
|
|
||||||
$ref: '#/components/schemas/ShieldType'
|
|
||||||
required:
|
required:
|
||||||
- shield_id
|
- shield_id
|
||||||
- shield_type
|
|
||||||
type: object
|
type: object
|
||||||
RestAPIExecutionConfig:
|
RestAPIExecutionConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -2698,8 +2695,6 @@ components:
|
||||||
type: string
|
type: string
|
||||||
provider_resource_id:
|
provider_resource_id:
|
||||||
type: string
|
type: string
|
||||||
shield_type:
|
|
||||||
$ref: '#/components/schemas/ShieldType'
|
|
||||||
type:
|
type:
|
||||||
const: shield
|
const: shield
|
||||||
default: shield
|
default: shield
|
||||||
|
@ -2709,7 +2704,6 @@ components:
|
||||||
- provider_resource_id
|
- provider_resource_id
|
||||||
- provider_id
|
- provider_id
|
||||||
- type
|
- type
|
||||||
- shield_type
|
|
||||||
title: A safety shield resource that can be used to check content
|
title: A safety shield resource that can be used to check content
|
||||||
type: object
|
type: object
|
||||||
ShieldCallStep:
|
ShieldCallStep:
|
||||||
|
@ -2736,13 +2730,6 @@ components:
|
||||||
- step_id
|
- step_id
|
||||||
- step_type
|
- step_type
|
||||||
type: object
|
type: object
|
||||||
ShieldType:
|
|
||||||
enum:
|
|
||||||
- generic_content_shield
|
|
||||||
- llama_guard
|
|
||||||
- code_scanner
|
|
||||||
- prompt_guard
|
|
||||||
type: string
|
|
||||||
SpanEndPayload:
|
SpanEndPayload:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3397,7 +3384,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
|
\ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -4761,24 +4748,24 @@ security:
|
||||||
servers:
|
servers:
|
||||||
- url: http://any-hosted-llama-stack.com
|
- url: http://any-hosted-llama-stack.com
|
||||||
tags:
|
tags:
|
||||||
- name: Inference
|
- name: MemoryBanks
|
||||||
|
- name: BatchInference
|
||||||
- name: Agents
|
- name: Agents
|
||||||
- name: Telemetry
|
- name: Inference
|
||||||
|
- name: DatasetIO
|
||||||
- name: Eval
|
- name: Eval
|
||||||
- name: Models
|
- name: Models
|
||||||
- name: Inspect
|
|
||||||
- name: EvalTasks
|
|
||||||
- name: ScoringFunctions
|
|
||||||
- name: Memory
|
|
||||||
- name: Safety
|
|
||||||
- name: DatasetIO
|
|
||||||
- name: MemoryBanks
|
|
||||||
- name: Shields
|
|
||||||
- name: PostTraining
|
- name: PostTraining
|
||||||
|
- name: ScoringFunctions
|
||||||
- name: Datasets
|
- name: Datasets
|
||||||
- name: Scoring
|
- name: Shields
|
||||||
|
- name: Telemetry
|
||||||
|
- name: Inspect
|
||||||
|
- name: Safety
|
||||||
- name: SyntheticDataGeneration
|
- name: SyntheticDataGeneration
|
||||||
- name: BatchInference
|
- name: Memory
|
||||||
|
- name: Scoring
|
||||||
|
- name: EvalTasks
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||||
name: BuiltinTool
|
name: BuiltinTool
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||||
|
@ -5046,8 +5033,6 @@ tags:
|
||||||
|
|
||||||
<SchemaDefinition schemaRef="#/components/schemas/Shield" />'
|
<SchemaDefinition schemaRef="#/components/schemas/Shield" />'
|
||||||
name: Shield
|
name: Shield
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldType" />
|
|
||||||
name: ShieldType
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
||||||
name: Trace
|
name: Trace
|
||||||
- description: 'Checkpoint created during training runs
|
- description: 'Checkpoint created during training runs
|
||||||
|
@ -5343,7 +5328,6 @@ x-tagGroups:
|
||||||
- Session
|
- Session
|
||||||
- Shield
|
- Shield
|
||||||
- ShieldCallStep
|
- ShieldCallStep
|
||||||
- ShieldType
|
|
||||||
- SpanEndPayload
|
- SpanEndPayload
|
||||||
- SpanStartPayload
|
- SpanStartPayload
|
||||||
- SpanStatus
|
- SpanStatus
|
||||||
|
|
|
@ -182,13 +182,13 @@
|
||||||
" pass\n",
|
" pass\n",
|
||||||
"\n",
|
"\n",
|
||||||
" async def run_shield(\n",
|
" async def run_shield(\n",
|
||||||
" self, shield_type: str, messages: List[dict]\n",
|
" self, shield_id: str, messages: List[dict]\n",
|
||||||
" ) -> RunShieldResponse:\n",
|
" ) -> RunShieldResponse:\n",
|
||||||
" async with httpx.AsyncClient() as client:\n",
|
" async with httpx.AsyncClient() as client:\n",
|
||||||
" response = await client.post(\n",
|
" response = await client.post(\n",
|
||||||
" f\"{self.base_url}/safety/run_shield\",\n",
|
" f\"{self.base_url}/safety/run_shield\",\n",
|
||||||
" json=dict(\n",
|
" json=dict(\n",
|
||||||
" shield_type=shield_type,\n",
|
" shield_id=shield_id,\n",
|
||||||
" messages=[encodable_dict(m) for m in messages],\n",
|
" messages=[encodable_dict(m) for m in messages],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" headers={\n",
|
" headers={\n",
|
||||||
|
@ -216,7 +216,7 @@
|
||||||
" ]:\n",
|
" ]:\n",
|
||||||
" cprint(f\"User>{message['content']}\", \"green\")\n",
|
" cprint(f\"User>{message['content']}\", \"green\")\n",
|
||||||
" response = await client.run_shield(\n",
|
" response = await client.run_shield(\n",
|
||||||
" shield_type=\"llama_guard\",\n",
|
" shield_id=\"Llama-Guard-3-1B\",\n",
|
||||||
" messages=[message],\n",
|
" messages=[message],\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" print(response)\n",
|
" print(response)\n",
|
||||||
|
|
|
@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||||
dataset_id: str
|
dataset_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
|
@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
|
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
|
||||||
eval_task_id: str
|
eval_task_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
|
@ -122,7 +122,6 @@ MemoryBank = Annotated[
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class MemoryBankInput(BaseModel):
|
class MemoryBankInput(BaseModel):
|
||||||
memory_bank_id: str
|
memory_bank_id: str
|
||||||
params: BankParams
|
params: BankParams
|
||||||
|
|
|
@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
|
@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||||
scoring_fn_id: str
|
scoring_fn_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
|
@ -37,7 +37,6 @@ class ShieldsClient(Shields):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
shield_type: ShieldType,
|
|
||||||
provider_shield_id: Optional[str],
|
provider_shield_id: Optional[str],
|
||||||
provider_id: Optional[str],
|
provider_id: Optional[str],
|
||||||
params: Optional[Dict[str, Any]],
|
params: Optional[Dict[str, Any]],
|
||||||
|
@ -47,7 +46,6 @@ class ShieldsClient(Shields):
|
||||||
f"{self.base_url}/shields/register",
|
f"{self.base_url}/shields/register",
|
||||||
json={
|
json={
|
||||||
"shield_id": shield_id,
|
"shield_id": shield_id,
|
||||||
"shield_type": shield_type,
|
|
||||||
"provider_shield_id": provider_shield_id,
|
"provider_shield_id": provider_shield_id,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"params": params,
|
"params": params,
|
||||||
|
@ -56,12 +54,12 @@ class ShieldsClient(Shields):
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[Shield]:
|
async def get_shield(self, shield_id: str) -> Optional[Shield]:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/shields/get",
|
f"{self.base_url}/shields/get",
|
||||||
params={
|
params={
|
||||||
"shield_type": shield_type,
|
"shield_id": shield_id,
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -13,16 +12,7 @@ from pydantic import BaseModel
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ShieldType(Enum):
|
|
||||||
generic_content_shield = "generic_content_shield"
|
|
||||||
llama_guard = "llama_guard"
|
|
||||||
code_scanner = "code_scanner"
|
|
||||||
prompt_guard = "prompt_guard"
|
|
||||||
|
|
||||||
|
|
||||||
class CommonShieldFields(BaseModel):
|
class CommonShieldFields(BaseModel):
|
||||||
shield_type: ShieldType
|
|
||||||
params: Optional[Dict[str, Any]] = None
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +49,6 @@ class Shields(Protocol):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
shield_type: ShieldType,
|
|
||||||
provider_shield_id: Optional[str] = None,
|
provider_shield_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
|
|
@ -172,13 +172,12 @@ class SafetyRouter(Safety):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
shield_type: ShieldType,
|
|
||||||
provider_shield_id: Optional[str] = None,
|
provider_shield_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> Shield:
|
) -> Shield:
|
||||||
return await self.routing_table.register_shield(
|
return await self.routing_table.register_shield(
|
||||||
shield_id, shield_type, provider_shield_id, provider_id, params
|
shield_id, provider_shield_id, provider_id, params
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
|
|
|
@ -241,7 +241,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
shield_type: ShieldType,
|
|
||||||
provider_shield_id: Optional[str] = None,
|
provider_shield_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
@ -260,7 +259,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
params = {}
|
params = {}
|
||||||
shield = Shield(
|
shield = Shield(
|
||||||
identifier=shield_id,
|
identifier=shield_id,
|
||||||
shield_type=shield_type,
|
|
||||||
provider_resource_id=provider_shield_id,
|
provider_resource_id=provider_shield_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
params=params,
|
params=params,
|
||||||
|
|
|
@ -14,6 +14,12 @@ from .config import CodeScannerConfig
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||||
|
"CodeScanner",
|
||||||
|
"CodeShield",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
if shield.shield_type != ShieldType.code_scanner:
|
if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
|
||||||
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
|
raise ValueError(
|
||||||
|
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
|
||||||
|
)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -6,32 +6,8 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.sku_list import CoreModelId, safety_models
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaGuardConfig(BaseModel):
|
class LlamaGuardConfig(BaseModel):
|
||||||
model: str = "Llama-Guard-3-1B"
|
|
||||||
excluded_categories: List[str] = []
|
excluded_categories: List[str] = []
|
||||||
|
|
||||||
@field_validator("model")
|
|
||||||
@classmethod
|
|
||||||
def validate_model(cls, model: str) -> str:
|
|
||||||
permitted_models = [
|
|
||||||
m.descriptor()
|
|
||||||
for m in safety_models()
|
|
||||||
if (
|
|
||||||
m.core_model_id
|
|
||||||
in {
|
|
||||||
CoreModelId.llama_guard_3_8b,
|
|
||||||
CoreModelId.llama_guard_3_1b,
|
|
||||||
CoreModelId.llama_guard_3_11b_vision,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
if model not in permitted_models:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid model: {model}. Must be one of {permitted_models}"
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
|
@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||||
CAT_ELECTIONS,
|
CAT_ELECTIONS,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
LLAMA_GUARD_MODEL_IDS = [
|
||||||
|
CoreModelId.llama_guard_3_8b.value,
|
||||||
|
CoreModelId.llama_guard_3_1b.value,
|
||||||
|
CoreModelId.llama_guard_3_11b_vision.value,
|
||||||
|
]
|
||||||
|
|
||||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||||
CoreModelId.llama_guard_3_8b.value: (
|
CoreModelId.llama_guard_3_8b.value: (
|
||||||
|
@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
self.inference_api = deps[Api.inference]
|
self.inference_api = deps[Api.inference]
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.shield = LlamaGuardShield(
|
pass
|
||||||
model=self.config.model,
|
|
||||||
inference_api=self.inference_api,
|
|
||||||
excluded_categories=self.config.excluded_categories,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
if shield.shield_type != ShieldType.llama_guard:
|
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
|
||||||
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
raise ValueError(
|
||||||
|
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
|
||||||
|
)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||||
messages[0] = UserMessage(content=messages[0].content)
|
messages[0] = UserMessage(content=messages[0].content)
|
||||||
|
|
||||||
return await self.shield.run(messages)
|
impl = LlamaGuardShield(
|
||||||
|
model=shield.provider_resource_id,
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
excluded_categories=self.config.excluded_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await impl.run(messages)
|
||||||
|
|
||||||
|
|
||||||
class LlamaGuardShield:
|
class LlamaGuardShield:
|
||||||
|
|
|
@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
if shield.shield_type != ShieldType.prompt_guard:
|
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||||
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
raise ValueError(
|
||||||
|
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
|
||||||
|
)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
BEDROCK_SUPPORTED_SHIELDS = [
|
|
||||||
ShieldType.generic_content_shield,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
|
@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
|
||||||
|
|
||||||
class TestAgents:
|
class TestAgents:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_turns_with_safety(self, agents_stack, common_params):
|
async def test_agent_turns_with_safety(
|
||||||
|
self, safety_model, agents_stack, common_params
|
||||||
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl, _ = agents_stack
|
||||||
agent_id, session_id = await create_agent_session(
|
agent_id, session_id = await create_agent_session(
|
||||||
agents_impl,
|
agents_impl,
|
||||||
AgentConfig(
|
AgentConfig(
|
||||||
**{
|
**{
|
||||||
**common_params,
|
**common_params,
|
||||||
"input_shields": ["llama_guard"],
|
"input_shields": [safety_model],
|
||||||
"output_shields": ["llama_guard"],
|
"output_shields": [safety_model],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput
|
from llama_stack.apis.models import ModelInput
|
||||||
|
|
||||||
from llama_stack.apis.shields import ShieldInput, ShieldType
|
from llama_stack.apis.shields import ShieldInput
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||||
|
@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="inline::llama-guard",
|
provider_id="inline::llama-guard",
|
||||||
provider_type="inline::llama-guard",
|
provider_type="inline::llama-guard",
|
||||||
config=LlamaGuardConfig(model=safety_model).model_dump(),
|
config=LlamaGuardConfig().model_dump(),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request):
|
||||||
|
|
||||||
|
|
||||||
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
|
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
|
||||||
shield_config = {}
|
if provider_type == "remote::bedrock":
|
||||||
shield_type = ShieldType.llama_guard
|
|
||||||
identifier = "llama_guard"
|
|
||||||
if provider_type == "meta-reference":
|
|
||||||
shield_config["model"] = safety_model
|
|
||||||
elif provider_type == "remote::together":
|
|
||||||
shield_config["model"] = safety_model
|
|
||||||
elif provider_type == "remote::bedrock":
|
|
||||||
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
||||||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
|
||||||
shield_type = ShieldType.generic_content_shield
|
else:
|
||||||
|
params = {}
|
||||||
|
identifier = safety_model
|
||||||
|
|
||||||
return ShieldInput(
|
return ShieldInput(
|
||||||
shield_id=identifier,
|
shield_id=identifier,
|
||||||
shield_type=shield_type,
|
params=params,
|
||||||
params=shield_config,
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,7 +34,6 @@ class TestSafety:
|
||||||
|
|
||||||
for shield in response:
|
for shield in response:
|
||||||
assert isinstance(shield, Shield)
|
assert isinstance(shield, Shield)
|
||||||
assert shield.shield_type in [v for v in ShieldType]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_shield(self, safety_stack):
|
async def test_run_shield(self, safety_stack):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue