add safety

This commit is contained in:
Ashwin Bharambe 2024-07-09 16:08:47 -07:00
parent 256f1d5991
commit 13e1667e7a
5 changed files with 279 additions and 106 deletions

View file

@ -9,6 +9,7 @@ from model_types import (
Message, Message,
PretrainedModel, PretrainedModel,
SamplingParams, SamplingParams,
SafetyViolation,
StopReason, StopReason,
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
@ -51,13 +52,6 @@ class ToolExecutionStep(ExecutionStepBase):
tool_responses: List[ToolResponse] tool_responses: List[ToolResponse]
@dataclass
class SafetyViolation:
violation_type: str
details: str
suggested_user_response: Optional[str] = None
@dataclass @dataclass
class SafetyFilteringStep(ExecutionStepBase): class SafetyFilteringStep(ExecutionStepBase):
step_type = ExecutionStepType.safety_filtering step_type = ExecutionStepType.safety_filtering

View file

@ -18,6 +18,7 @@ from model_types import (
PretrainedModel, PretrainedModel,
SamplingParams, SamplingParams,
StopReason, StopReason,
ShieldConfig,
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
ToolResponse, ToolResponse,
@ -118,13 +119,16 @@ class AgenticSystemCreateRequest:
instructions: str instructions: str
model: InstructModel model: InstructModel
# zero-shot tool definitions as input to the model # zero-shot or built-in tool configurations as input to the model
available_tools: List[Union[BuiltinTool, ToolDefinition]] = field( available_tools: List[ToolDefinition] = field(default_factory=list)
default_factory=list
)
# tools which aren't executable are emitted as tool calls which the users can
# execute themselves.
executable_tools: Set[str] = field(default_factory=set) executable_tools: Set[str] = field(default_factory=set)
input_shields: List[ShieldConfig] = field(default_factory=list)
output_shields: List[ShieldConfig] = field(default_factory=list)
@json_schema_type @json_schema_type
@dataclass @dataclass

View file

@ -5,6 +5,28 @@ from typing import Any, Dict, List, Optional, Set, Union
from strong_typing.schema import json_schema_type from strong_typing.schema import json_schema_type
# Closed set of safety-shield kinds that a ShieldConfig can select.
# Each member's value is the same string as its name; those strings are what
# the generated API spec lists as the allowed `shield_type` values.
# NOTE: the class docstring below is surfaced verbatim as the `title` of the
# `shield_type` schema in the generated spec, so editing it changes the spec.
# What each shield actually checks is not defined in this module.
class ShieldType(Enum):
"""The type of safety shield."""
llama_guard = "llama_guard"
prompt_guard = "prompt_guard"
code_guard = "code_guard"
# Configuration for one safety shield: which kind to use plus free-form
# parameters for it. Registered with @json_schema_type, so it is exported as
# the `ShieldConfig` component of the generated API schema.
# Used as the element type of the `input_shields` / `output_shields` lists on
# ToolDefinition and on the agentic-system create request.
@json_schema_type
@dataclass
class ShieldConfig:
# Which shield this entry configures (see ShieldType).
shield_type: ShieldType
# Arbitrary shield-specific settings; defaults to a fresh empty dict per
# instance. The accepted keys are not defined here -- presumably they depend
# on `shield_type`; TODO confirm against the shield implementations.
# NOTE(review): despite the default, the generated spec lists `params`
# under `required`.
params: Dict[str, Any] = field(default_factory=dict)
# Record describing a single safety violation.
# This is a plain dataclass (not decorated with @json_schema_type). The same
# definition previously lived next to SafetyFilteringStep; that module now
# imports it from `model_types` instead.
@dataclass
class SafetyViolation:
# Label for the kind of violation; free-form string, no enum is enforced here.
violation_type: str
# Description of the violation.
details: str
# Optional text that could be shown to the user in place of the flagged
# content; None when no suggestion is available.
# NOTE(review): intended use inferred from the field name -- confirm.
suggested_user_response: Optional[str] = None
@json_schema_type( @json_schema_type(
schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"} schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
) )
@ -58,24 +80,22 @@ class ToolResponse:
response: str response: str
@dataclass
class ToolDefinition:
tool_name: str
parameters: Dict[str, Any]
# TODO: we need to document the parameters for the tool calls # TODO: we need to document the parameters for the tool calls
class BuiltinTool(Enum): class BuiltinTool(Enum):
"""
Builtin tools are tools the model is natively aware of and was potentially fine-tuned with.
"""
web_search = "web_search" web_search = "web_search"
math = "math" math = "math"
image_gen = "image_gen" image_gen = "image_gen"
code_interpreter = "code_interpreter" code_interpreter = "code_interpreter"
# Describes one tool made available to the model, either a builtin tool or a
# custom one identified by name, together with optional per-tool shields.
# Declared after BuiltinTool because the `tool_name` annotation refers to it.
@dataclass
class ToolDefinition:
# Either a BuiltinTool member or an arbitrary string naming a custom tool.
tool_name: Union[BuiltinTool, str]
# Optional mapping describing the tool's parameters; None when not supplied.
# The expected structure of this mapping is not documented yet (see the
# TODO above BuiltinTool).
parameters: Optional[Dict[str, Any]] = None
# Shield configurations attached to this tool; each list defaults to a fresh
# empty list per instance.
# NOTE(review): presumably `input_shields` screen what is sent to the tool
# and `output_shields` screen what it returns -- confirm with the executor.
input_shields: List[ShieldConfig] = field(default_factory=list)
output_shields: List[ShieldConfig] = field(default_factory=list)
class StopReason(Enum): class StopReason(Enum):
""" """
Stop reasons are used to indicate why the model stopped generating text. Stop reasons are used to indicate why the model stopped generating text.
@ -117,6 +137,3 @@ class PretrainedModel(Enum):
class InstructModel(Enum): class InstructModel(Enum):
llama3_8b_chat = "llama3_8b_chat" llama3_8b_chat = "llama3_8b_chat"
llama3_70b_chat = "llama3_70b_chat" llama3_70b_chat = "llama3_70b_chat"

View file

@ -174,6 +174,50 @@
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
"components": { "components": {
"schemas": { "schemas": {
"ShieldConfig": {
"type": "object",
"properties": {
"shield_type": {
"type": "string",
"enum": [
"llama_guard",
"prompt_guard",
"code_guard"
],
"title": "The type of safety shield."
},
"params": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"shield_type",
"params"
]
},
"AgenticSystemCreateRequest": { "AgenticSystemCreateRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -190,55 +234,67 @@
"available_tools": { "available_tools": {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [ "type": "object",
{ "properties": {
"type": "string", "tool_name": {
"enum": [ "oneOf": [
"web_search", {
"math", "type": "string",
"image_gen", "enum": [
"code_interpreter" "web_search",
], "math",
"title": "Builtin tools are tools the model is natively aware of and was potentially fine-tuned with." "image_gen",
}, "code_interpreter"
{ ]
"type": "object",
"properties": {
"tool_name": {
"type": "string"
}, },
"parameters": { {
"type": "object", "type": "string"
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
} }
},
"additionalProperties": false,
"required": [
"tool_name",
"parameters"
] ]
},
"parameters": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
} }
},
"additionalProperties": false,
"required": [
"tool_name",
"input_shields",
"output_shields"
] ]
} }
}, },
@ -248,6 +304,18 @@
"type": "string" "type": "string"
}, },
"uniqueItems": true "uniqueItems": true
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -255,7 +323,9 @@
"instructions", "instructions",
"model", "model",
"available_tools", "available_tools",
"executable_tools" "executable_tools",
"input_shields",
"output_shields"
] ]
}, },
"AgenticSystemCreateResponse": { "AgenticSystemCreateResponse": {
@ -897,14 +967,26 @@
"math", "math",
"image_gen", "image_gen",
"code_interpreter" "code_interpreter"
], ]
"title": "Builtin tools are tools the model is natively aware of and was potentially fine-tuned with."
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"tool_name": { "tool_name": {
"type": "string" "oneOf": [
{
"type": "string",
"enum": [
"web_search",
"math",
"image_gen",
"code_interpreter"
]
},
{
"type": "string"
}
]
}, },
"parameters": { "parameters": {
"type": "object", "type": "object",
@ -930,12 +1012,25 @@
} }
] ]
} }
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"tool_name", "tool_name",
"parameters" "input_shields",
"output_shields"
] ]
} }
] ]
@ -1344,11 +1439,15 @@
} }
], ],
"tags": [ "tags": [
{
"name": "Inference"
},
{ {
"name": "AgenticSystem" "name": "AgenticSystem"
}, },
{ {
"name": "Inference" "name": "ShieldConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldConfig\" />"
}, },
{ {
"name": "AgenticSystemCreateRequest", "name": "AgenticSystemCreateRequest",
@ -1436,6 +1535,7 @@
"CompletionResponse", "CompletionResponse",
"CompletionResponseStreamChunk", "CompletionResponseStreamChunk",
"Message", "Message",
"ShieldConfig",
"URL" "URL"
] ]
} }

View file

@ -6,39 +6,50 @@ components:
properties: properties:
available_tools: available_tools:
items: items:
oneOf: additionalProperties: false
- enum: properties:
- web_search input_shields:
- math items:
- image_gen $ref: '#/components/schemas/ShieldConfig'
- code_interpreter type: array
title: Builtin tools are tools the model is natively aware of and was output_shields:
potentially fine-tuned with. items:
type: string $ref: '#/components/schemas/ShieldConfig'
- additionalProperties: false type: array
properties: parameters:
parameters: additionalProperties:
additionalProperties: oneOf:
oneOf: - type: 'null'
- type: 'null' - type: boolean
- type: boolean - type: number
- type: number - type: string
- type: string - type: array
- type: array - type: object
- type: object type: object
type: object tool_name:
tool_name: oneOf:
- enum:
- web_search
- math
- image_gen
- code_interpreter
type: string type: string
required: - type: string
- tool_name required:
- parameters - tool_name
type: object - input_shields
- output_shields
type: object
type: array type: array
executable_tools: executable_tools:
items: items:
type: string type: string
type: array type: array
uniqueItems: true uniqueItems: true
input_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
instructions: instructions:
type: string type: string
model: model:
@ -46,11 +57,17 @@ components:
- llama3_8b_chat - llama3_8b_chat
- llama3_70b_chat - llama3_70b_chat
type: string type: string
output_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
required: required:
- instructions - instructions
- model - model
- available_tools - available_tools
- executable_tools - executable_tools
- input_shields
- output_shields
type: object type: object
AgenticSystemCreateResponse: AgenticSystemCreateResponse:
additionalProperties: false additionalProperties: false
@ -375,11 +392,17 @@ components:
- math - math
- image_gen - image_gen
- code_interpreter - code_interpreter
title: Builtin tools are tools the model is natively aware of and was
potentially fine-tuned with.
type: string type: string
- additionalProperties: false - additionalProperties: false
properties: properties:
input_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
output_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
parameters: parameters:
additionalProperties: additionalProperties:
oneOf: oneOf:
@ -391,10 +414,18 @@ components:
- type: object - type: object
type: object type: object
tool_name: tool_name:
type: string oneOf:
- enum:
- web_search
- math
- image_gen
- code_interpreter
type: string
- type: string
required: required:
- tool_name - tool_name
- parameters - input_shields
- output_shields
type: object type: object
type: array type: array
logprobs: logprobs:
@ -719,6 +750,30 @@ components:
- tool_calls - tool_calls
- tool_responses - tool_responses
type: object type: object
ShieldConfig:
additionalProperties: false
properties:
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
shield_type:
enum:
- llama_guard
- prompt_guard
- code_guard
title: The type of safety shield.
type: string
required:
- shield_type
- params
type: object
URL: URL:
format: uri format: uri
pattern: ^(https?://|file://|data:) pattern: ^(https?://|file://|data:)
@ -815,8 +870,10 @@ security:
servers: servers:
- url: http://llama.meta.com - url: http://llama.meta.com
tags: tags:
- name: AgenticSystem
- name: Inference - name: Inference
- name: AgenticSystem
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
/> />
name: AgenticSystemCreateRequest name: AgenticSystemCreateRequest
@ -903,4 +960,5 @@ x-tagGroups:
- CompletionResponse - CompletionResponse
- CompletionResponseStreamChunk - CompletionResponseStreamChunk
- Message - Message
- ShieldConfig
- URL - URL