fix: Safety in starter (#2731)

- Fireworks and Together no longer serve the Llama Guard 3 8B model
- Safety therefore needs to default to Ollama
- The existing safety-shields logic was incorrect: the shield_id was set to the provider id, which produced duplicate shield ids
- Followed logic similar to how models are registered

Note: this may seem a bit over-engineered, but it can now be extended to other
providers and fits into the overall mechanism of using env vars to manage the
starter distribution (see the sketch below).
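
As an illustration only (not code from this PR), here is a minimal sketch of the intended behavior: safety models are keyed by an env-var-gated provider id, and shield ids are prefixed with the provider id only when the same model id shows up under more than one provider. `build_shields` and the simplified `ShieldInput` below are hypothetical stand-ins; the real helpers (`get_safety_models_for_providers`, `get_shield_registry`) are in the diff further down.

```python
from dataclasses import dataclass


@dataclass
class ShieldInput:
    # Simplified stand-in for llama_stack.distribution.datatypes.ShieldInput
    shield_id: str
    provider_shield_id: str


def build_shields(safety_models_by_provider: dict[str, list[str]]) -> list[ShieldInput]:
    # Step 1: detect shield-id conflicts across providers (the old bug: shield_id
    # was the provider id itself, so multiple shields for one provider collided).
    seen: set[str] = set()
    conflict = False
    for model_ids in safety_models_by_provider.values():
        for model_id in model_ids:
            if model_id in seen:
                conflict = True
            seen.add(model_id)

    # Step 2: register one shield per safety model; prefix with the provider id
    # only if the bare model id would be ambiguous.
    shields = []
    for provider_id, model_ids in safety_models_by_provider.items():
        for model_id in model_ids:
            shield_id = f"{provider_id}/{model_id}" if conflict else model_id
            shields.append(ShieldInput(shield_id=shield_id, provider_shield_id=f"{provider_id}/{model_id}"))
    return shields


if __name__ == "__main__":
    # With the starter defaults, only ollama contributes a SAFETY_MODEL-driven entry.
    print(build_shields({"${env.ENABLE_OLLAMA:=__disabled__}": ["${env.SAFETY_MODEL:=__disabled__}"]}))
```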

### How to test 
```
ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ --stack-config starter -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct --safety-shield llama-guard3:1b --embedding-model all-MiniLM-L6-v2
```

### Related but not obvious in this PR 
In the llama-stack-ops repo, we run tests before publishing packages and
Docker containers. The actions in that repo were using the fireworks / together
distros (which no longer exist).

Those actions need to be updated to run with the `starter` distribution and to
use `ollama` specifically for safety, along the lines of the command below.
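
For example, the ops job could invoke something like the following (illustrative only; the exact flags and model ids used in llama-stack-ops may differ):
```
ENABLE_OLLAMA=ollama SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ \
  --stack-config starter \
  --text-model ollama/llama3.2:3b-instruct-fp16 \
  --safety-shield llama-guard3:1b \
  --embedding-model all-MiniLM-L6-v2
```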
Commit 6b8a8c1be9 (parent 6ad22c209f)
Hardik Shah, 2025-07-14 15:07:40 -07:00, committed by GitHub
9 changed files with 104 additions and 195 deletions


@@ -89,7 +89,7 @@ jobs:
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
-            --safety-shield=ollama \
+            --safety-shield=$SAFETY_MODEL \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log


@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES


@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",


@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",


@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []


@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
+
     return DistributionTemplate(
         name=name,


@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+    return shields
+
+
 class DefaultModel(BaseModel):


@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",


@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config
 
 
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()
 
 
-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-    Srinivas, the CEO, worked at OpenAI as an AI researcher.
-    Konwinski was among the founding team at Databricks.
-    Yarats, the CTO, was an AI research scientist at Meta.
-    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],