Merge remote-tracking branch 'origin/main' into RFC-0001-The-Llama-Stack

Ashwin Bharambe 2024-08-20 18:58:47 -07:00
commit 75bbe787b6
35 changed files with 309 additions and 90 deletions

View file

@@ -110,6 +110,35 @@ class Session(BaseModel):
started_at: datetime
@json_schema_type
class ToolPromptFormat(Enum):
"""This Enum refers to the prompt format for calling zero shot tools
`json` --
Refers to the json format for calling tools.
The json format takes the form like
{
"type": "function",
"function" : {
"name": "function_name",
"description": "function_description",
"parameters": {...}
}
}
`function_tag` --
This is an example of how you could define
your own user defined format for making tool calls.
The function_tag format looks like this,
<function=function_name>(parameters)</function>
The detailed prompts for each of these formats are defined in `system_prompt.py`
"""
json = "json"
function_tag = "function_tag"
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
@@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel):
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
class AgenticSystemTurnResponseEventType(Enum):
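For concreteness, the same call to a hypothetical get_boiling_point tool would be serialized under the two formats roughly as follows; the shapes mirror the docstring above, and the authoritative templates live in system_prompt.py.

# Indicative only: shapes taken from the ToolPromptFormat docstring, not
# captured model output. get_boiling_point is a made-up custom tool.
json_call = {
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "description": "Get the boiling point of an imaginary liquid",
        "parameters": {},  # ToolParamDefinition entries go here
    },
}
function_tag_call = "<function=get_boiling_point>(parameters)</function>"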

View file

@@ -13,8 +13,15 @@ import fire
import httpx
from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
from llama_models.llama3.api.datatypes import (
BuiltinTool,
SamplingParams,
ToolParamDefinition,
UserMessage,
)
from termcolor import cprint
from llama_toolchain.agentic_system.event_logger import EventLogger
from .api import (
AgenticSystem,
AgenticSystemCreateRequest,
@@ -25,6 +32,7 @@ from .api import (
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
ToolPromptFormat,
)
@@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
async def run_main(host: str, port: int):
# client to test remote impl of agentic system
api = await AgenticSystemClient(f"http://{host}:{port}")
api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [
AgenticSystemToolDefinition(
@@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
AgenticSystemToolDefinition(
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
]
tool_definitions += [
AgenticSystemToolDefinition(
tool_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="str",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
]
create_request = AgenticSystemCreateRequest(
model="Meta-Llama3.1-8B-Instruct",
@@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
output_shields=[],
quantization_config=None,
debug_prefix_messages=[],
tool_prompt_format=ToolPromptFormat.json,
),
)
create_response = await api.create_agentic_system(create_request)
print(create_response)
# TODO: Add chat session / turn apis to test e2e
session_response = await api.create_agentic_system_session(
AgenticSystemSessionCreateRequest(
system_id=create_response.system_id,
session_name="test_session",
)
)
print(session_response)
user_prompts = [
"Who are you?",
"what is the 100th prime number?",
"Search web for who was 44th President of USA?",
"Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion ?",
]
for content in user_prompts:
cprint(f"User> {content}", color="blue")
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
system_id=create_response.system_id,
session_id=session_response.session_id,
messages=[
UserMessage(content=content),
],
stream=True,
)
)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
def main(host: str, port: int):
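The trailing def main(host, port) (body elided in this hunk) presumably just bridges run_main into fire; the new end-to-end loop can also be driven directly, e.g.:

# Sketch: drive the end-to-end test against a local server. Host and port
# are examples; a server implementing the agentic system API must already
# be listening there.
import asyncio

from llama_toolchain.agentic_system.client import run_main

asyncio.run(run_main(host="localhost", port=5000))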

View file

@@ -6,16 +6,16 @@
from typing import Optional
from llama_models.llama3_1.api.datatypes import ToolResponseMessage
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.llama3.api.datatypes import ToolResponseMessage
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_toolchain.agentic_system.api import (
AgenticSystemTurnResponseEventType,
StepType,
)
from termcolor import cprint
class LogEvent:
def __init__(

View file

@@ -10,6 +10,8 @@ import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
ShieldCallStep,
StepType,
ToolExecutionStep,
ToolPromptFormat,
Turn,
)
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
):
self.system_id = system_id
self.instance_config = instance_config
@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
self.prefix_messages = prefix_messages
else:
self.prefix_messages = get_agentic_prefix_messages(
builtin_tools, custom_tool_definitions
builtin_tools,
custom_tool_definitions,
tool_prompt_format,
)
for m in self.prefix_messages:

View file

@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
)
return AgenticSystemCreateResponse(

View file

@@ -6,14 +6,15 @@
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from termcolor import cprint
from llama_toolchain.safety.api.datatypes import (
OnViolationAction,
ShieldDefinition,
ShieldResponse,
)
from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
from termcolor import cprint
class SafetyException(Exception): # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
messages[0] = UserMessage(content=messages[0].content)
res = await self.safety_api.run_shields(
RunShieldRequest(

View file

@@ -5,21 +5,27 @@
# the root directory of this source tree.
import json
import textwrap
from datetime import datetime
from typing import List
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api import (
BuiltinTool,
Message,
SystemMessage,
ToolDefinition,
UserMessage,
)
from .tools.builtin import SingleMessageBuiltinTool
def get_agentic_prefix_messages(
builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
builtin_tools: List[SingleMessageBuiltinTool],
custom_tools: List[ToolDefinition],
tool_prompt_format: ToolPromptFormat,
) -> List[Message]:
messages = []
content = ""
@@ -34,28 +40,52 @@ def get_agentic_prefix_messages(
]
)
if tool_str:
content += f"Tools: {tool_str}\n"
content += f"Tools: {tool_str}"
current_date = datetime.now()
formatted_date = current_date.strftime("%d %B %Y")
date_str = f"""
Cutting Knowledge Date: December 2023
Today Date: {formatted_date}\n\n"""
Today Date: {formatted_date}\n"""
content += date_str
messages.append(SystemMessage(content=content))
if custom_tools:
custom_message = get_system_prompt_for_custom_tools(custom_tools)
content += custom_message
if tool_prompt_format == ToolPromptFormat.function_tag:
text = prompt_for_function_tag(custom_tools)
messages.append(UserMessage(content=text))
elif tool_prompt_format == ToolPromptFormat.json:
text = prompt_for_json(custom_tools)
messages.append(UserMessage(content=text))
else:
raise NotImplementedError(
f"Tool prompt format {tool_prompt_format} is not supported"
)
else:
messages.append(SystemMessage(content=content))
# TODO: Replace this hard coded message with instructions coming in the request
if False:
content += "You are a helpful Assistant."
messages.append(SystemMessage(content=content))
return messages
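Net effect of the branching above: with custom tools, the prefix becomes a SystemMessage (builtin tools plus date) followed by a UserMessage carrying the format-specific instructions; without them, a single SystemMessage. Sketched for the json case, with abbreviated, non-verbatim strings:

# Rough shape of the returned prefix when custom tools are present and
# tool_prompt_format is ToolPromptFormat.json (content abbreviated):
prefix = [
    SystemMessage(content="Tools: brave_search, wolfram_alpha\n"
                          "Cutting Knowledge Date: December 2023\n"
                          "Today Date: 20 August 2024\n"),
    UserMessage(content=prompt_for_json(custom_tools)),
]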
def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
tool_defs = "\n".join(
translate_custom_tool_definition_to_json(t) for t in custom_tools
)
content = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the functions can be used, please say so.
Here is a list of functions in JSON format:
{tool_defs}
Return function calls in JSON format.
"""
)
content = content.lstrip("\n").format(tool_defs=tool_defs)
return content
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
custom_tool_params = ""
for t in custom_tools:
custom_tool_params += get_instruction_string(t) + "\n"
@@ -76,7 +106,6 @@ Reminder:
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
return content
@@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str:
)
# NOTE: Unused right now
def translate_custom_tool_definition_to_json(tool_def):
"""Translates ToolDefinition to json as expected by model
eg. output for a function
@@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def):
else:
func_def["function"]["parameters"] = {}
return json.dumps(func_def)
return json.dumps(func_def, indent=4)
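For the get_boiling_point tool defined in client.py, the translated definition would come out roughly like this; the field layout is assembled from the ToolPromptFormat docstring and the fragment above, not from captured output:

# Hypothetical output of translate_custom_tool_definition_to_json for the
# get_boiling_point tool (layout assumed from the docstrings, pretty-printed
# per the indent=4 change above):
expected = {
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "description": "Get the boiling point of an imaginary liquid (e.g. polyjuice)",
        "parameters": {
            "liquid_name": {
                "param_type": "str",
                "description": "The name of the liquid",
                "required": True,
            },
            "celcius": {
                "param_type": "str",
                "description": "Whether to return the boiling point in Celsius",
                "required": False,
            },
        },
    },
}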

View file

@@ -9,7 +9,7 @@ import json
from abc import abstractmethod
from typing import Dict, List
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
# TODO: this is symptomatic of us needing to pull more tooling related utilities

View file

@@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, List
from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
from llama_toolchain.agentic_system.api import (
AgenticSystem,

View file

@@ -7,7 +7,7 @@
import uuid
from typing import Any, List, Optional
from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
from llama_toolchain.agentic_system.api import (
AgenticSystemCreateRequest,
@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
custom_tools: Optional[List[Any]] = None,
disable_safety: bool = False,
model: str = "Meta-Llama3.1-8B-Instruct",
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> AgenticSystemClientWrapper:
custom_tools = custom_tools or []
@@ -113,6 +115,7 @@ async def get_agent_system_instance(
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
),
)
create_response = await api.create_agentic_system(create_request)

View file

@@ -6,18 +6,22 @@
import argparse
import asyncio
import json
import os
import shutil
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, List
import httpx
from llama_toolchain.cli.subcommand import Subcommand
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
@@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
choices=[x.descriptor() for x in models],
required=True,
required=False,
)
parser.add_argument(
"--hf-token",
@@ -70,6 +74,12 @@ For source=huggingface, files matching any of the patterns are not downloaded. D
safetensors files to avoid downloading duplicate weights.
""",
)
parser.add_argument(
"--manifest-file",
type=str,
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
required=False,
)
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
@@ -88,7 +98,7 @@ def _hf_download(
if repo_id is None:
raise ValueError(f"No repo id found for model {model.descriptor()}")
output_dir = model_local_dir(model)
output_dir = model_local_dir(model.descriptor())
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
@@ -118,7 +128,7 @@ def _meta_download(model: "Model", meta_url: str):
from llama_toolchain.common.model_utils import model_local_dir
output_dir = Path(model_local_dir(model))
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
info = llama_meta_net_info(model)
@@ -139,6 +149,14 @@ def _meta_download(model: "Model", meta_url: str):
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import resolve_model
if args.manifest_file:
_download_from_manifest(args.manifest_file)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
model = resolve_model(args.model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
@@ -156,6 +174,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
_meta_download(model, meta_url)
class ModelEntry(BaseModel):
model_id: str
files: Dict[str, str]
class Manifest(BaseModel):
models: List[ModelEntry]
expires_on: datetime
def _download_from_manifest(manifest_file: str):
from llama_toolchain.common.model_utils import model_local_dir
with open(manifest_file, "r") as f:
d = json.load(f)
manifest = Manifest(**d)
if datetime.now() > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
for entry in manifest.models:
print(f"Downloading model {entry.model_id}...")
output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
cprint(f"Output directory {output_dir} is not empty.", "red")
while True:
resp = input(
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
)
if resp.lower() == "restart" or resp.lower() == "r":
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
break
elif resp.lower() == "continue" or resp.lower() == "c":
print("Continuing download...")
break
else:
cprint("Invalid response. Please try again.", "red")
for fname, url in entry.files.items():
output_file = str(output_dir / fname)
downloader = ResumableDownloader(url, output_file)
asyncio.run(downloader.download())
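The ModelEntry/Manifest models pin down the manifest shape, so a file they would accept can be sketched directly; the model id and URLs below are placeholders, not real download links:

# Build a manifest that Manifest(**json.load(f)) above would accept.
import json
from datetime import datetime, timedelta

manifest = {
    "models": [
        {
            "model_id": "Meta-Llama3.1-8B-Instruct",
            "files": {
                "consolidated.00.pth": "https://example.com/signed/consolidated.00.pth",
                "tokenizer.model": "https://example.com/signed/tokenizer.model",
            },
        }
    ],
    # pydantic parses ISO-8601 strings into the datetime field
    "expires_on": (datetime.now() + timedelta(days=1)).isoformat(),
}
with open("manifest.json", "w") as f:
    json.dump(manifest, f, indent=2)
# Then (assuming the usual CLI entry point): llama download --manifest-file manifest.json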
class ResumableDownloader:
def __init__(
self,
@@ -190,7 +256,7 @@ class ResumableDownloader:
async def download(self) -> None:
self.start_time = time.time()
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(follow_redirects=True) as client:
await self.get_file_info(client)
if os.path.exists(self.output_file):
@@ -222,7 +288,7 @@ class ResumableDownloader:
headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
}
# print(f"Downloading `{self.output_file}`....{headers}")
print(f"Downloading `{self.output_file}`....{headers}")
try:
async with client.stream(
"GET", self.url, headers=headers

View file

@@ -7,10 +7,10 @@
import argparse
import textwrap
from llama_toolchain.cli.subcommand import Subcommand
from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""
@@ -48,10 +48,11 @@ class ModelTemplate(Subcommand):
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
from llama_models.llama3_1.api.interface import (
from llama_models.llama3.api.interface import (
list_jinja_templates,
render_jinja_template,
)
from llama_toolchain.cli.table import print_table
if args.name:

View file

@@ -7,7 +7,7 @@
from enum import Enum
from typing import Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type

View file

@@ -1,9 +1,13 @@
import os
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.datatypes import Model
import os
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(model: Model) -> str:
return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
def model_local_dir(descriptor: str) -> str:
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3_1.api.datatypes import URL
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel

View file

@@ -7,7 +7,7 @@
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403

View file

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):

View file

@@ -22,22 +22,22 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.reference_impl.model import Transformer
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.inference.api import QuantizationType
from termcolor import cprint
from .config import MetaReferenceImplConfig
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model))
checkpoint_dir = Path(model_local_dir(model.descriptor()))
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
checkpoint_dir = checkpoint_dir / "original"

View file

@@ -8,7 +8,7 @@ import asyncio
from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec

View file

@@ -10,9 +10,9 @@ from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig

View file

@@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict
import httpx
from llama_models.llama3_1.api.datatypes import (
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
Message,
StopReason,
ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
@@ -30,7 +32,6 @@ from llama_toolchain.inference.api import (
ToolCallDelta,
ToolCallParseStatus,
)
from ollama import AsyncClient
from .config import OllamaImplConfig
@@ -64,10 +65,10 @@ class OllamaInference(Inference):
async def initialize(self) -> None:
try:
await self.client.ps()
except httpx.ConnectError:
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
)
) from e
async def shutdown(self) -> None:
pass

View file

@@ -13,7 +13,7 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,

View file

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
from .datatypes import * # noqa: F403

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type

View file

@@ -7,13 +7,12 @@
from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from pydantic import BaseModel, validator
from pydantic import BaseModel
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
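The effect of this pre-validator (duplicated on ShieldDefinition above) is that strings naming a builtin shield are upgraded to BuiltinShield members, while unknown strings pass through unchanged:

# Assuming BuiltinShield.llama_guard has the string value "llama_guard"
# and that no other ShieldDefinition fields are required for this sketch:
sd = ShieldDefinition(shield_type="llama_guard")
assert sd.shield_type == BuiltinShield.llama_guard
sd = ShieldDefinition(shield_type="my_custom_shield")
assert sd.shield_type == "my_custom_shield"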

View file

@@ -7,7 +7,7 @@
from .datatypes import * # noqa: F403
from typing import List, Protocol
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3.api.datatypes import Message
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import webmethod

View file

@@ -9,7 +9,7 @@ import asyncio
import fire
import httpx
from llama_models.llama3_1.api.datatypes import UserMessage
from llama_models.llama3.api.datatypes import UserMessage
from termcolor import cprint
from .api import (

View file

@@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model)
model_dir = model_local_dir(model.descriptor())
return model_dir
@@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety):
return RunShieldResponse(responses=responses)
def shield_type_equals(a: ShieldType, b: ShieldType):
return a == b or a == b.value
def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
if sc.shield_type == BuiltinShield.llama_guard:
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
assert (
safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif sc.shield_type == BuiltinShield.jailbreak_shield:
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.injection_shield:
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.code_scanner_guard:
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
return CodeScannerShield.instance()
elif sc.shield_type == BuiltinShield.third_party_shield:
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
return ThirdPartyShield.instance()
else:
raise ValueError(f"Unknown shield type: {sc.shield_type}")
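The helper exists because shield_type may arrive either as the enum member or as its plain string value, depending on how the definition was constructed and serialized; comparing against both keeps the two spellings interchangeable:

# Both spellings of a builtin shield type compare equal here
# (assuming BuiltinShield.llama_guard has the value "llama_guard"):
assert shield_type_equals(BuiltinShield.llama_guard, BuiltinShield.llama_guard)
assert shield_type_equals("llama_guard", BuiltinShield.llama_guard)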

View file

@@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from typing import List, Union
from llama_models.llama3_1.api.datatypes import Attachment, Message
from llama_models.llama3.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@@ -6,7 +6,7 @@
from typing import List
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3.api.datatypes import Message
from llama_toolchain.safety.meta_reference.shields.base import (
OnViolationAction,

View file

@@ -10,7 +10,7 @@ from string import Template
from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message, Role
from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

View file

@@ -9,7 +9,7 @@ from typing import List
import torch
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403

View file

@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_toolchain",
version="0.0.5",
version="0.0.8",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama toolchain",