llama-stack (mirror of https://github.com/meta-llama/llama-stack.git)
Merge 4e720dcbef into 8ed69978f9 (commit e32ecc7d33)
9 changed files with 144 additions and 24 deletions
@@ -6,9 +6,7 @@
 
 from typing import Any
 
-from llama_stack.apis.inference import (
-    Message,
-)
+from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
@@ -68,6 +66,7 @@ class SafetyRouter(Safety):
         list_shields_response = await self.routing_table.list_shields()
 
         matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
+
         if not matches:
             raise ValueError(f"No shield associated with provider_resource id {model}")
         if len(matches) > 1:
@@ -28,6 +28,7 @@ distribution_spec:
     - provider_type: inline::localfs
     safety:
     - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
     agents:
     - provider_type: inline::meta-reference
     telemetry:
@@ -134,6 +134,8 @@ providers:
     provider_type: inline::llama-guard
     config:
       excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
 - shield_id: llama-guard
   provider_id: ${env.SAFETY_MODEL:+llama-guard}
   provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
 vector_dbs: []
 datasets: []
 scoring_fns: []
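The `${env.VAR:+value}` and `${env.VAR:=default}` references above follow the bash-style substitution used throughout these run.yaml files: `:+` yields the alternate value only when the variable is set, and `:=` falls back to a default (empty here) when it is not. A minimal Python sketch of that resolution logic, for illustration only (the helper `resolve_env_ref` is hypothetical, not part of llama-stack):

```python
import os
import re


# Hypothetical helper illustrating how ${env.VAR:=default} and ${env.VAR:+alternate} resolve.
def resolve_env_ref(ref: str) -> str | None:
    m = re.fullmatch(r"\$\{env\.(\w+):([=+])(.*)\}", ref)
    if not m:
        return ref  # not an env reference; use the literal value
    var, op, operand = m.groups()
    value = os.environ.get(var)
    if op == "=":  # ${env.VAR:=default} -> the variable's value, else the default (empty means unset)
        return value if value else (operand or None)
    return operand if value else None  # ${env.VAR:+alternate} -> alternate only when VAR is set


print(resolve_env_ref("${env.CODE_SCANNER_MODEL:+code-scanner}"))  # "code-scanner" if the env var is set, else None
```

So with CODE_SCANNER_MODEL unset, the code-scanner shield entry resolves to no provider and is effectively disabled; setting the variable enables it.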
@@ -28,6 +28,7 @@ distribution_spec:
     - provider_type: inline::localfs
     safety:
    - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
     agents:
     - provider_type: inline::meta-reference
     telemetry:
@@ -134,6 +134,8 @@ providers:
     provider_type: inline::llama-guard
     config:
       excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
 - shield_id: llama-guard
   provider_id: ${env.SAFETY_MODEL:+llama-guard}
   provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
 vector_dbs: []
 datasets: []
 scoring_fns: []
@@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
     ToolGroupInput,
 )
 from llama_stack.core.utils.dynamic import instantiate_class_type
-from llama_stack.distributions.template import (
-    DistributionTemplate,
-    RunConfigSettings,
-)
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.datatypes import RemoteProviderSpec
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.providers.inline.vector_io.milvus.config import (
-    MilvusVectorIOConfig,
-)
+from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
     SQLiteVectorIOConfig,
 )
@@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="remote::pgvector"),
         ],
         "files": [BuildProvider(provider_type="inline::localfs")],
-        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "safety": [
+            BuildProvider(provider_type="inline::llama-guard"),
+            BuildProvider(provider_type="inline::code-scanner"),
+        ],
         "agents": [BuildProvider(provider_type="inline::meta-reference")],
         "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
         "post_training": [BuildProvider(provider_type="inline::huggingface")],
@@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.SAFETY_MODEL:+llama-guard}",
             provider_shield_id="${env.SAFETY_MODEL:=}",
         ),
+        ShieldInput(
+            shield_id="code-scanner",
+            provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
+            provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
+        ),
     ]
 
     return DistributionTemplate(
@@ -5,7 +5,11 @@
 # the root directory of this source tree.
 
 import logging
-from typing import Any
+import uuid
+from typing import TYPE_CHECKING, Any
 
+if TYPE_CHECKING:
+    from codeshield.cs import CodeShieldScanResult
+
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -24,8 +29,8 @@ from .config import CodeScannerConfig
 log = logging.getLogger(__name__)
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [
-    "CodeScanner",
-    "CodeShield",
+    "code-scanner",
+    "code-shield",
 ]
 
 
@@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
             )
         return RunShieldResponse(violation=violation)
+
+    def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
+        categories = {}
+        category_scores = {}
+        category_applied_input_types = {}
+
+        flagged = scan_result.is_insecure
+        user_message = None
+        metadata = {}
+
+        if scan_result.is_insecure:
+            pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
+            categories = dict.fromkeys(pattern_ids, True)
+            category_scores = dict.fromkeys(pattern_ids, 1.0)
+            category_applied_input_types = {key: ["text"] for key in pattern_ids}
+            user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
+            metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
+
+        return ModerationObjectResults(
+            flagged=flagged,
+            categories=categories,
+            category_scores=category_scores,
+            category_applied_input_types=category_applied_input_types,
+            user_message=user_message,
+            metadata=metadata,
+        )
+
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        inputs = input if isinstance(input, list) else [input]
+        results = []
+
+        from codeshield.cs import CodeShield
+
+        for text_input in inputs:
+            log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
+            try:
+                scan_result = await CodeShield.scan_code(text_input)
+                moderation_result = self.get_moderation_object_results(scan_result)
+            except Exception as e:
+                log.error(f"CodeShield.scan_code failed: {e}")
+                # create safe fallback response on scanner failure to avoid blocking legitimate requests
+                moderation_result = ModerationObjectResults(
+                    flagged=False,
+                    categories={},
+                    category_scores={},
+                    category_applied_input_types={},
+                    user_message=None,
+                    metadata={"scanner_error": str(e)},
+                )
+            results.append(moderation_result)
+
+        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
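For reference, a flagged scan produces a ModerationObject whose per-result categories are keyed by CodeShield pattern ids (see the `dict.fromkeys(pattern_ids, ...)` calls above) rather than a fixed category list. A rough sketch of the shape returned for one insecure input; the pattern id, message, and uuid below are illustrative only, not taken from a real scan:

```python
# Illustrative shape of the run_moderation output for a single flagged input.
example_moderation_object = {
    "id": "c0ffee00-0000-4000-8000-000000000000",  # uuid4 generated per call
    "model": "code-scanner",                        # the `model` argument passed to run_moderation
    "results": [
        {
            "flagged": True,
            "categories": {"weak-md5-hashing": True},            # keyed by CodeShield pattern_id (example name)
            "category_scores": {"weak-md5-hashing": 1.0},         # binary scanner: 1.0 when flagged
            "category_applied_input_types": {"weak-md5-hashing": ["text"]},
            "user_message": "Security concerns detected in the code. WARN: use of weak hashing algorithm",
            "metadata": {"violation_type": "weak-md5-hashing"},
        }
    ],
}
```

On scanner failure the except branch above returns an unflagged result with the error recorded under `metadata["scanner_error"]`, so a broken scanner does not block otherwise legitimate requests.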
@@ -11,11 +11,7 @@ from string import Template
 from typing import Any
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import (
-    Inference,
-    Message,
-    UserMessage,
-)
+from llama_stack.apis.inference import Inference, Message, UserMessage
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
 }
 SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
 
-
 DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_VIOLENT_CRIMES,
     CAT_NON_VIOLENT_CRIMES,
@@ -460,7 +455,7 @@ class LlamaGuardShield:
 
     def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
         """Check if content is safe based on response and unsafe code."""
-        if response.strip() == SAFE_RESPONSE:
+        if response.strip().lower().startswith(SAFE_RESPONSE):
            return True
 
         if unsafe_code:
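The switch from strict equality to a case-insensitive prefix match makes the verdict check tolerant of Llama Guard outputs that vary in casing or carry trailing text after the verdict. A small illustration, assuming SAFE_RESPONSE is the lowercase literal "safe" (the unsafe_code branch of the real method is omitted here):

```python
SAFE_RESPONSE = "safe"  # assumed value of the module-level constant

def is_content_safe(response: str) -> bool:
    # new behaviour: case-insensitive prefix match instead of exact string equality
    return response.strip().lower().startswith(SAFE_RESPONSE)

assert is_content_safe("safe")   # accepted by both the old and the new check
assert is_content_safe("Safe")   # accepted only by the new check; the old `==` comparison was case-sensitive
assert not is_content_safe("unsafe\nS8")
```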
@@ -27,10 +27,10 @@ def data_url_from_image(file_path):
 
 @pytest.fixture(scope="session")
 def code_scanner_shield_id(available_shields):
-    if "CodeScanner" in available_shields:
-        return "CodeScanner"
+    if "code-scanner" in available_shields:
+        return "code-scanner"
 
-    pytest.skip("CodeScanner shield is not available. Skipping.")
+    pytest.skip("code-scanner shield is not available. Skipping.")
 
 
 def test_unsafe_examples(client_with_models, shield_id):
@@ -163,6 +163,60 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo
     assert response.violation.user_message == "Sorry, I found security concerns in the code."
 
 
+def test_unsafe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    insecure_code = [
+        """
+        Here is an example of code that hashes a password using MD5:
+
+        ```
+        import hashlib
+        password = 'securepassword123'.encode()
+        hashed_password = hashlib.md5(password).hexdigest()
+        print("Insecure hashed password:", hashed_password)
+        ```
+        """
+    ]
+    moderation_object = client_with_models.moderations.create(
+        input=insecure_code,
+        model=code_scanner_shield_id,
+    )
+    assert moderation_object.results[0].flagged is True, f"Code scanner should have flagged {insecure_code} as insecure"
+    assert all(value is True for value in moderation_object.results[0].categories.values()), (
+        "Code scanner shield should have detected code insecure category"
+    )
+
+
+def test_safe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    secure_code = [
+        """
+        Extract the first 5 characters from a string:
+        ```
+        text = "Hello World"
+        first_five = text[:5]
+        print(first_five) # Output: "Hello"
+
+        # Safe handling for strings shorter than 5 characters
+        def get_first_five(text):
+            return text[:5] if text else ""
+        ```
+        """
+    ]
+    moderation_object = client_with_models.moderations.create(
+        input=secure_code,
+        model=code_scanner_shield_id,
+    )
+
+    assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
+
+
 # We can use an instance of the LlamaGuard shield to detect attempts to misuse
 # the interpreter as this is one of the existing categories it checks for
 def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):
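Outside the pytest fixtures, the same moderation call can be issued directly with the llama-stack client. A sketch assuming a stack is running locally on the default port with the code-scanner shield registered as in the run.yaml changes above; the shield identifier mirrors what the new integration tests pass as `model`:

```python
from llama_stack_client import LlamaStackClient

# Assumes a llama-stack server at the default local address with CODE_SCANNER_MODEL set
# so that the code-scanner shield is registered.
client = LlamaStackClient(base_url="http://localhost:8321")

moderation = client.moderations.create(
    input=["import hashlib\nhashlib.md5(b'secret').hexdigest()"],
    model="code-scanner",  # shield identifier used by the new integration tests
)
print(moderation.results[0].flagged, moderation.results[0].categories)
```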