mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 21:34:36 +00:00
Merge branch 'main' into fix_llama_guard_inference
This commit is contained in:
commit
7a8b5c1604
75 changed files with 979 additions and 182 deletions
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# supress warnings and spew of logs from hugging face
|
||||
import transformers
|
||||
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
from termcolor import cprint
|
||||
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
|
@ -5,7 +17,11 @@ from llama_models.llama3_1.api.datatypes import Message
|
|||
|
||||
parent_dir = "../.."
|
||||
sys.path.append(parent_dir)
|
||||
from llama_toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_toolchain.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
)
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import re
|
||||
|
||||
from string import Template
|
||||
|
@ -100,7 +112,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
|
@ -119,7 +131,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
):
|
||||
|
@ -129,6 +141,8 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import auto, Enum
|
||||
from typing import List
|
||||
|
||||
|
|
|
@ -1,3 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
|
@ -6,7 +18,7 @@ from llama_models.llama3_1.api.datatypes import Message, Role
|
|||
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
||||
|
||||
class SafetyException(Exception):
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, response: ShieldResponse):
|
||||
self.response = response
|
||||
super().__init__(response.violation_return_message)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue