fix safety in starter

Hardik Shah 2025-07-11 16:42:24 -07:00
parent 51d9fd4808
commit 3d83322003
8 changed files with 84 additions and 111 deletions

View file

@@ -89,7 +89,7 @@ jobs:
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
-            --safety-shield=ollama \
+            --safety-shield=$SAFETY_MODEL \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log

View file

@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
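
Note (illustrative, not part of this commit): moving the Llama Guard entries into SAFETY_MODELS_ENTRIES and re-appending them keeps the combined ollama MODEL_ENTRIES unchanged while letting safety-specific code import the guard entries on their own. A toy sketch of the pattern:

# Toy stand-in for ProviderModelEntry; the real entries come from
# build_hf_repo_model_entry in llama_stack.
class Entry:
    def __init__(self, provider_model_id):
        self.provider_model_id = provider_model_id

SAFETY_MODELS_ENTRIES = [Entry("llama-guard3:8b"), Entry("llama-guard3:1b")]
MODEL_ENTRIES = [Entry("llama3.1:8b-instruct-fp16")] + SAFETY_MODELS_ENTRIES

# The full list still contains the guard models, so existing consumers of
# MODEL_ENTRIES are unaffected; new code can use SAFETY_MODELS_ENTRIES alone.
assert all(e in MODEL_ENTRIES for e in SAFETY_MODELS_ENTRIES)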

View file

@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",

View file

@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",

View file

@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
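
Note (illustrative): the shield entries rely on llama-stack's ${env.VAR:=default} substitution. The toy resolver below (not the project's actual implementation) shows why an unset SAFETY_MODEL now leaves the shield disabled instead of silently defaulting to llama-guard3:1b:

import os
import re

# Toy resolver for the ${env.VAR:=default} syntax; llama-stack has its own
# substitution logic, this only illustrates the semantics.
def resolve(template):
    return re.sub(
        r"\$\{env\.([A-Za-z0-9_]+):=([^}]*)\}",
        lambda m: os.environ.get(m.group(1), m.group(2)),
        template,
    )

print(resolve("${env.SAFETY_MODEL:=__disabled__}"))  # unset -> __disabled__
os.environ["SAFETY_MODEL"] = "llama-guard3:1b"
print(resolve("${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}"))
# -> __disabled__/llama-guard3:1b  (shield stays off until ENABLE_OLLAMA is set too)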

View file

@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
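
Note (illustrative): after this change every provider other than ollama falls through to the empty default, so only ollama can contribute safety models to the starter. A toy version of the trimmed lookup:

# Only "ollama" has safety entries now; the string stands in for a
# ProviderModelEntry.
safety_model_entries_map = {
    "ollama": ["${env.SAFETY_MODEL:=__disabled__}"],
}

def _get_model_safety_entries_for_provider(provider_type):
    return safety_model_entries_map.get(provider_type, [])

assert _get_model_safety_entries_for_provider("ollama") == ["${env.SAFETY_MODEL:=__disabled__}"]
assert _get_model_safety_entries_for_provider("fireworks") == []  # was FIREWORKS_SAFETY_MODELS_ENTRIES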
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
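
Note (illustrative): a standalone repro of the provider-id construction above. The .replace('::', '_') appears redundant, since provider_type was already split on '::', but it is harmless:

# Reproduces the env-var naming scheme used in get_safety_models_for_providers
# (sketch; the real code reads Provider objects, not bare strings).
for provider_type in ["ollama", "llama-openai-compat"]:
    env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
    provider_id = f"${{env.{env_var}:=__disabled__}}"
    print(provider_id)
# -> ${env.ENABLE_OLLAMA:=__disabled__}
# -> ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}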
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
 
     return DistributionTemplate(
         name=name,

View file

@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
 
 
 class DefaultModel(BaseModel):
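
Note (illustrative): a toy walk-through of the new get_shield_registry for the starter's single ollama safety entry. Setting ids_conflict_in_models=True mirrors the starter run.yaml above, where model ids collide across providers, so provider_shield_id gets the provider prefix while shield_id stays bare:

# Toy stand-in for ProviderModelEntry.
class Entry:
    def __init__(self, provider_model_id, aliases=None):
        self.provider_model_id = provider_model_id
        self.aliases = aliases or []

available_safety_models = {
    "${env.ENABLE_OLLAMA:=__disabled__}": [Entry("${env.SAFETY_MODEL:=__disabled__}")],
}
ids_conflict_in_models = True  # assumed: the starter's model ids conflict

for provider_id, entries in available_safety_models.items():
    for entry in entries:
        shield_id = entry.provider_model_id  # single entry -> no shield-id conflict, unprefixed
        provider_shield_id = (
            f"{provider_id}/{entry.provider_model_id}" if ids_conflict_in_models else entry.provider_model_id
        )
        print(f"shield_id={shield_id}")
        print(f"provider_shield_id={provider_shield_id}")
# Matches the regenerated run.yaml:
#   shield_id=${env.SAFETY_MODEL:=__disabled__}
#   provider_shield_id=${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}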

View file

@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",