diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index c46100c38..1a8d6734f 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -89,7 +89,7 @@ jobs:
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
-            --safety-shield=ollama \
+            --safety-shield=$SAFETY_MODEL \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index 64ddb23d9..7c0a19a1a 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index 4eccfb25c..e5c13aa74 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index 942905dae..56ee9c47d 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index 888a2c3bf..ad449cb1b 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index 6b8aa8974..c0ac44183 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
 
 
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
 
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
 
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
            continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
 
     return DistributionTemplate(
         name=name,
diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py
index dceb13c8b..fb2528873 100644
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
 
 
 class DefaultModel(BaseModel):
diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py
index 7fa3a55e5..ea185f05d 100644
--- a/llama_stack/templates/watsonx/watsonx.py
+++ b/llama_stack/templates/watsonx/watsonx.py
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
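
Reviewer note (illustration only, not part of the patch): the sketch below traces how the single ollama shield in the starter run.yaml above falls out of the two new helpers. The import paths for ProviderModelEntry and ModelType are assumptions based on the surrounding codebase, and ProviderModelEntry.aliases is assumed to default to an empty list.

# Sketch: reproduce the starter distro's shield entry from run.yaml.
from llama_stack.apis.models import ModelType  # assumed import path
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry  # assumed import path
from llama_stack.templates.template import get_shield_registry

# Shape produced by get_safety_models_for_providers(): keyed by the
# env-gated provider id, valued by that provider's safety model entries.
available_safety_models = {
    "${env.ENABLE_OLLAMA:=__disabled__}": [
        ProviderModelEntry(
            provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
            model_type=ModelType.llm,
        ),
    ],
}

# With a single entry there is no shield-id conflict, so shield_id stays
# unprefixed; ids_conflict_in_models=True (the starter distro registers many
# inference providers) makes provider_shield_id provider-prefixed, yielding:
#   shield_id:          ${env.SAFETY_MODEL:=__disabled__}
#   provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
shields = get_shield_registry(available_safety_models, ids_conflict_in_models=True)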