add safety

This commit is contained in:
Ashwin Bharambe 2024-07-09 16:08:47 -07:00
parent 256f1d5991
commit 13e1667e7a
5 changed files with 279 additions and 106 deletions

View file

@ -9,6 +9,7 @@ from model_types import (
Message,
PretrainedModel,
SamplingParams,
SafetyViolation,
StopReason,
ToolCall,
ToolDefinition,
@ -51,13 +52,6 @@ class ToolExecutionStep(ExecutionStepBase):
tool_responses: List[ToolResponse]
@dataclass
class SafetyViolation:
violation_type: str
details: str
suggested_user_response: Optional[str] = None
@dataclass
class SafetyFilteringStep(ExecutionStepBase):
step_type = ExecutionStepType.safety_filtering

View file

@ -18,6 +18,7 @@ from model_types import (
PretrainedModel,
SamplingParams,
StopReason,
ShieldConfig,
ToolCall,
ToolDefinition,
ToolResponse,
@ -118,13 +119,16 @@ class AgenticSystemCreateRequest:
instructions: str
model: InstructModel
# zero-shot tool definitions as input to the model
available_tools: List[Union[BuiltinTool, ToolDefinition]] = field(
default_factory=list
)
# zero-shot or built-in tool configurations as input to the model
available_tools: List[ToolDefinition] = field(default_factory=list)
# tools which aren't executable are emitted as tool calls which the users can
# execute themselves.
executable_tools: Set[str] = field(default_factory=set)
input_shields: List[ShieldConfig] = field(default_factory=list)
output_shields: List[ShieldConfig] = field(default_factory=list)
@json_schema_type
@dataclass

View file

@ -5,6 +5,28 @@ from typing import Any, Dict, List, Optional, Set, Union
from strong_typing.schema import json_schema_type
class ShieldType(Enum):
"""The type of safety shield."""
llama_guard = "llama_guard"
prompt_guard = "prompt_guard"
code_guard = "code_guard"
@json_schema_type
@dataclass
class ShieldConfig:
shield_type: ShieldType
params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SafetyViolation:
violation_type: str
details: str
suggested_user_response: Optional[str] = None
@json_schema_type(
schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
)
@ -58,24 +80,22 @@ class ToolResponse:
response: str
@dataclass
class ToolDefinition:
tool_name: str
parameters: Dict[str, Any]
# TODO: we need to document the parameters for the tool calls
class BuiltinTool(Enum):
"""
Builtin tools are tools the model is natively aware of and was potentially fine-tuned with.
"""
web_search = "web_search"
math = "math"
image_gen = "image_gen"
code_interpreter = "code_interpreter"
@dataclass
class ToolDefinition:
tool_name: Union[BuiltinTool, str]
parameters: Optional[Dict[str, Any]] = None
input_shields: List[ShieldConfig] = field(default_factory=list)
output_shields: List[ShieldConfig] = field(default_factory=list)
class StopReason(Enum):
"""
Stop reasons are used to indicate why the model stopped generating text.
@ -117,6 +137,3 @@ class PretrainedModel(Enum):
class InstructModel(Enum):
llama3_8b_chat = "llama3_8b_chat"
llama3_70b_chat = "llama3_70b_chat"

View file

@ -174,6 +174,50 @@
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
"components": {
"schemas": {
"ShieldConfig": {
"type": "object",
"properties": {
"shield_type": {
"type": "string",
"enum": [
"llama_guard",
"prompt_guard",
"code_guard"
],
"title": "The type of safety shield."
},
"params": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"shield_type",
"params"
]
},
"AgenticSystemCreateRequest": {
"type": "object",
"properties": {
@ -190,55 +234,67 @@
"available_tools": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string",
"enum": [
"web_search",
"math",
"image_gen",
"code_interpreter"
],
"title": "Builtin tools are tools the model is natively aware of and was potentially fine-tuned with."
},
{
"type": "object",
"properties": {
"tool_name": {
"type": "string"
"type": "object",
"properties": {
"tool_name": {
"oneOf": [
{
"type": "string",
"enum": [
"web_search",
"math",
"image_gen",
"code_interpreter"
]
},
"parameters": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
{
"type": "string"
}
},
"additionalProperties": false,
"required": [
"tool_name",
"parameters"
]
},
"parameters": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
}
},
"additionalProperties": false,
"required": [
"tool_name",
"input_shields",
"output_shields"
]
}
},
@ -248,6 +304,18 @@
"type": "string"
},
"uniqueItems": true
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
}
},
"additionalProperties": false,
@ -255,7 +323,9 @@
"instructions",
"model",
"available_tools",
"executable_tools"
"executable_tools",
"input_shields",
"output_shields"
]
},
"AgenticSystemCreateResponse": {
@ -897,14 +967,26 @@
"math",
"image_gen",
"code_interpreter"
],
"title": "Builtin tools are tools the model is natively aware of and was potentially fine-tuned with."
]
},
{
"type": "object",
"properties": {
"tool_name": {
"type": "string"
"oneOf": [
{
"type": "string",
"enum": [
"web_search",
"math",
"image_gen",
"code_interpreter"
]
},
{
"type": "string"
}
]
},
"parameters": {
"type": "object",
@ -930,12 +1012,25 @@
}
]
}
},
"input_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
},
"output_shields": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ShieldConfig"
}
}
},
"additionalProperties": false,
"required": [
"tool_name",
"parameters"
"input_shields",
"output_shields"
]
}
]
@ -1344,11 +1439,15 @@
}
],
"tags": [
{
"name": "Inference"
},
{
"name": "AgenticSystem"
},
{
"name": "Inference"
"name": "ShieldConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldConfig\" />"
},
{
"name": "AgenticSystemCreateRequest",
@ -1436,6 +1535,7 @@
"CompletionResponse",
"CompletionResponseStreamChunk",
"Message",
"ShieldConfig",
"URL"
]
}

View file

@ -6,39 +6,50 @@ components:
properties:
available_tools:
items:
oneOf:
- enum:
- web_search
- math
- image_gen
- code_interpreter
title: Builtin tools are tools the model is natively aware of and was
potentially fine-tuned with.
type: string
- additionalProperties: false
properties:
parameters:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
tool_name:
additionalProperties: false
properties:
input_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
output_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
parameters:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
tool_name:
oneOf:
- enum:
- web_search
- math
- image_gen
- code_interpreter
type: string
required:
- tool_name
- parameters
type: object
- type: string
required:
- tool_name
- input_shields
- output_shields
type: object
type: array
executable_tools:
items:
type: string
type: array
uniqueItems: true
input_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
instructions:
type: string
model:
@ -46,11 +57,17 @@ components:
- llama3_8b_chat
- llama3_70b_chat
type: string
output_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
required:
- instructions
- model
- available_tools
- executable_tools
- input_shields
- output_shields
type: object
AgenticSystemCreateResponse:
additionalProperties: false
@ -375,11 +392,17 @@ components:
- math
- image_gen
- code_interpreter
title: Builtin tools are tools the model is natively aware of and was
potentially fine-tuned with.
type: string
- additionalProperties: false
properties:
input_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
output_shields:
items:
$ref: '#/components/schemas/ShieldConfig'
type: array
parameters:
additionalProperties:
oneOf:
@ -391,10 +414,18 @@ components:
- type: object
type: object
tool_name:
type: string
oneOf:
- enum:
- web_search
- math
- image_gen
- code_interpreter
type: string
- type: string
required:
- tool_name
- parameters
- input_shields
- output_shields
type: object
type: array
logprobs:
@ -719,6 +750,30 @@ components:
- tool_calls
- tool_responses
type: object
ShieldConfig:
additionalProperties: false
properties:
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
shield_type:
enum:
- llama_guard
- prompt_guard
- code_guard
title: The type of safety shield.
type: string
required:
- shield_type
- params
type: object
URL:
format: uri
pattern: ^(https?://|file://|data:)
@ -815,8 +870,10 @@ security:
servers:
- url: http://llama.meta.com
tags:
- name: AgenticSystem
- name: Inference
- name: AgenticSystem
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
/>
name: AgenticSystemCreateRequest
@ -903,4 +960,5 @@ x-tagGroups:
- CompletionResponse
- CompletionResponseStreamChunk
- Message
- ShieldConfig
- URL