ci: test safety with starter (#2628)

# What does this PR do?

We are now testing the safety capability with the starter image. This
includes a few changes:

* Enable the safety integration test
* Relax the llama-guard provider's shield model requirements so it also
  works with llama-guard3:8b served from Ollama
* Expose a shield for each inference provider in the starter distro. The
  shield is only registered when the provider is enabled (see the sketch
  below).
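
For illustration, the shield generation added to the starter template boils down to the sketch below (simplified from the `get_shields_for_providers` helper in this PR; error handling and debug output are omitted):

```python
# One ShieldInput per safety model a provider ships. shield_id reuses the
# provider's ${env.ENABLE_<PROVIDER>} variable, so a disabled provider
# produces a "__disabled__" shield that registration skips at startup.
def get_shields_for_providers(providers):
    shields = []
    for provider in providers:
        provider_type = provider.provider_type.split("::")[1]
        for entry in _get_model_safety_entries_for_provider(provider_type):
            shields.append(
                ShieldInput(
                    provider_id="llama-guard",
                    shield_id=provider.provider_id,
                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={entry.provider_model_id}}}",
                )
            )
    return shields
```

For Ollama this yields `${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}`, the exact entry visible in the generated run.yaml below.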

Closes: https://github.com/meta-llama/llama-stack/issues/2528
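
To exercise a registered shield manually, a quick check along these lines should work (hypothetical snippet, assuming the `llama-stack-client` Python package and a starter server launched with `ENABLE_OLLAMA=ollama` and `SAFETY_MODEL=llama-guard3:1b`, matching the CI setup in this PR):

```python
from llama_stack_client import LlamaStackClient

# The shield_id is the enabled provider's id, here "ollama"
# (the CI run passes --safety-shield=ollama for the same reason).
client = LlamaStackClient(base_url="http://localhost:8321")
response = client.safety.run_shield(
    shield_id="ollama",
    messages=[{"role": "user", "content": "How do I pick a lock?"}],
    params={},
)
print(response.violation)  # None when the message is deemed safe
```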

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han, 2025-07-09 16:53:50 +02:00, committed by GitHub
parent de01eefdef, commit 9b7eecebcf
20 changed files with 621 additions and 126 deletions

View file

@@ -7,3 +7,5 @@ runs:
     shell: bash
     run: |
       docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
+      # TODO: rebuild an ollama image with llama-guard3:1b
+      docker exec ollama ollama pull llama-guard3:1b

View file

@@ -24,7 +24,7 @@ jobs:
       matrix:
         # Listing tests manually since some of them currently fail
         # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
+        test-type: [agents, inference, datasets, inspect, safety, scoring, post_training, providers, tool_runtime, vector_io]
         client-type: [library, server]
         python-version: ["3.12", "3.13"]
         fail-fast: false # we want to run all tests regardless of failure
@@ -51,11 +51,23 @@ jobs:
           free -h
           df -h

+      - name: Verify Ollama status is OK
+        if: matrix.client-type == 'http'
+        run: |
+          echo "Verifying Ollama status..."
+          ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
+          echo "Ollama status: $ollama_status"
+          if [ "$ollama_status" != "OK" ]; then
+            echo "Ollama health check failed"
+            exit 1
+          fi
+
       - name: Run Integration Tests
         env:
-          OLLAMA_INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" # for server tests
+          OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
           ENABLE_OLLAMA: "ollama" # for server tests
           OLLAMA_URL: "http://0.0.0.0:11434"
+          SAFETY_MODEL: "llama-guard3:1b"
         # Use 'shell' to get pipefail behavior
         # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
         # TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash'
@@ -68,8 +80,9 @@ jobs:
           fi
           uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
-            --text-model="ollama/meta-llama/Llama-3.2-3B-Instruct" \
+            --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
+            --safety-shield=ollama \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log

View file

@@ -98,6 +98,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
         method = getattr(impls[api], register_method)
         for obj in objects:
+            logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
             # Do not register models on disabled providers
             if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
                 logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
@@ -112,6 +113,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
             ):
                 logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
                 continue
+
+            if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
+                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
+                continue
+
             # we want to maintain the type information in arguments to method.
             # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
             # we use model_dump() to find all the attrs and then getattr to get the still typed value.

View file

@@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         pass

     async def register_shield(self, shield: Shield) -> None:
-        if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
-            raise ValueError(
-                f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
-            )
+        # Allow any model to be registered as a shield
+        # The model will be validated during runtime when making inference calls
+        pass

     async def run_shield(
         self,
@@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)

-        model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
+        # Use the inference API's model resolution instead of hardcoded mappings
+        # This allows the shield to work with any registered model
+        model_id = shield.provider_resource_id
+
+        # Determine safety categories based on the model type
+        # For known Llama Guard models, use specific categories
+        if model_id in LLAMA_GUARD_MODEL_IDS:
+            # Use the mapped model for categories but the original model_id for inference
+            mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
+            safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
+        else:
+            # For unknown models, use default Llama Guard 3 8B categories
+            safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
+
         impl = LlamaGuardShield(
-            model=model,
+            model=model_id,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
+            safety_categories=safety_categories,
         )

         return await impl.run(messages)
@@ -183,20 +196,21 @@ class LlamaGuardShield:
         model: str,
         inference_api: Inference,
         excluded_categories: list[str] | None = None,
+        safety_categories: list[str] | None = None,
     ):
         if excluded_categories is None:
             excluded_categories = []
+        if safety_categories is None:
+            safety_categories = []

         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

-        if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
-            raise ValueError(f"Unsupported model: {model}")
-
         self.model = model
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories
+        self.safety_categories = safety_categories

     def check_unsafe_response(self, response: str) -> str | None:
         match = re.match(r"^unsafe\n(.*)$", response)
@@ -214,7 +228,7 @@ class LlamaGuardShield:
         final_categories = []

-        all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
+        all_categories = self.safety_categories
         for cat in all_categories:
             cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
             if cat_code in excluded_categories:

View file

@@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
     "anthropic/claude-3-5-haiku-latest",
 ]

+SAFETY_MODELS_ENTRIES = []
+
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-3",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 1024, "context_length": 32000},
-    ),
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-3-lite",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 512, "context_length": 32000},
-    ),
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-code-3",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 1024, "context_length": 32000},
-    ),
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-3",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 1024, "context_length": 32000},
+        ),
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-3-lite",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 512, "context_length": 32000},
+        ),
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-code-3",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 1024, "context_length": 32000},
+        ),
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

View file

@@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = []
+
+# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta.llama3-1-8b-instruct-v1:0",
@@ -22,4 +26,4 @@ MODEL_ENTRIES = [
         "meta.llama3-1-405b-instruct-v1:0",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = []
+
+# https://inference-docs.cerebras.ai/models
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1-8b",
@@ -18,4 +21,8 @@ MODEL_ENTRIES = [
         "llama-3.3-70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-]
+    build_hf_repo_model_entry(
+        "llama-4-scout-17b-16e-instruct",
+        CoreModelId.llama4_scout_17b_16e_instruct.value,
+    ),
+] + SAFETY_MODELS_ENTRIES

View file

@@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import DatabricksImplConfig

-model_entries = [
+SAFETY_MODELS_ENTRIES = []
+
+# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
+MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "databricks-meta-llama-3-1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
@@ -56,7 +59,7 @@ model_entries = [
         "databricks-meta-llama-3-1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES


 class DatabricksInferenceAdapter(
@@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
     OpenAICompletionToLlamaStackMixin,
 ):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=model_entries)
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
         self.config = config

     async def initialize(self) -> None:

View file

@@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama-guard-3-8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama-guard-3-11b-vision",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p1-8b-instruct",
@@ -40,14 +51,6 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-v3p3-70b-instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "accounts/fireworks/models/llama-guard-3-8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "accounts/fireworks/models/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
     build_hf_repo_model_entry(
         "accounts/fireworks/models/llama4-scout-instruct-basic",
         CoreModelId.llama4_scout_17b_16e_instruct.value,
@@ -64,4 +67,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
     "gemini/gemini-2.5-pro",
 ]

+SAFETY_MODELS_ENTRIES = []
+
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id="gemini/text-embedding-004",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 768, "context_length": 2048},
-    ),
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id="gemini/text-embedding-004",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 768, "context_length": 2048},
+        ),
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

View file

@@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )

+SAFETY_MODELS_ENTRIES = []
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "groq/llama3-8b-8192",
@@ -51,4 +53,4 @@ MODEL_ENTRIES = [
         "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
         CoreModelId.llama4_maverick_17b_128e_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = []
+
+# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta/llama3-8b-instruct",
@@ -99,4 +102,4 @@ MODEL_ENTRIES = [
     ),
     # TODO(mf): how do we handle Nemotron models?
     # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
     "text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
     "text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
 }

+SAFETY_MODELS_ENTRIES = []

-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id=model_id,
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": model_info.embedding_dimension,
-            "context_length": model_info.context_length,
-        },
-    )
-    for model_id, model_info in EMBEDDING_MODEL_IDS.items()
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id=model_id,
+            model_type=ModelType.embedding,
+            metadata={
+                "embedding_dimension": model_info.embedding_dimension,
+                "context_length": model_info.context_length,
+            },
+        )
+        for model_id, model_info in EMBEDDING_MODEL_IDS.items()
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

View file

@@ -11,7 +11,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse

 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
@@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import RunpodImplConfig

+# https://docs.runpod.io/serverless/vllm/overview#compatible-models
+# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
 RUNPOD_SUPPORTED_MODELS = {
     "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
     "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }

+SAFETY_MODELS_ENTRIES = []
+
+# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(provider_model_id, model_descriptor)
+    for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
+] + SAFETY_MODELS_ENTRIES
+

 class RunpodInferenceAdapter(
     ModelRegistryHelper,
View file

@@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "sambanova/Meta-Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "sambanova/Meta-Llama-3.1-8B-Instruct",
@@ -46,8 +54,4 @@ MODEL_ENTRIES = [
         "sambanova/Llama-4-Maverick-17B-128E-Instruct",
         CoreModelId.llama4_maverick_17b_128e_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "sambanova/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "meta-llama/Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@@ -40,14 +50,6 @@ MODEL_ENTRIES = [
         "meta-llama/Llama-3.3-70B-Instruct-Turbo",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "meta-llama/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
     ProviderModelEntry(
         provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
         model_type=ModelType.embedding,
@@ -78,4 +80,4 @@ MODEL_ENTRIES = [
             "together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
         ],
     ),
-]
+] + SAFETY_MODELS_ENTRIES

View file

@@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
     ]
     if additional_aliases:
         aliases.extend(additional_aliases)
+    aliases = [alias for alias in aliases if alias is not None]
     return ProviderModelEntry(
         provider_model_id=provider_model_id,
         aliases=aliases,
@@ -90,7 +91,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
         provider_resource_id = model.provider_resource_id
         if provider_resource_id:
-            if provider_resource_id != supported_model_id:  # be idemopotent, only reject differences
+            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
                 raise ValueError(
                     f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
                 )

View file

@@ -256,11 +256,46 @@ inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
 models:
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama3.1-8b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama3.1-8b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama-3.3-70b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama-3.3-70b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama-4-scout-17b-16e-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
+  provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
+  provider_model_id: llama-4-scout-17b-16e-instruct
+  model_type: llm
 - metadata: {}
   model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
   provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
   provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
   model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
+  provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
+  provider_model_id: ${env.SAFETY_MODEL:=__disabled__}
+  model_type: llm
 - metadata:
     embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384}
   model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}
@@ -342,26 +377,6 @@ models:
   provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
   provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
   model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
-  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
-  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
-  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
-  model_type: llm
 - metadata: {}
   model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic
   provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
@@ -389,6 +404,26 @@ models:
   provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
   provider_model_id: nomic-ai/nomic-embed-text-v1.5
   model_type: embedding
+- metadata: {}
+  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
+  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
+  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
+  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
+  provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
+  model_type: llm
 - metadata: {}
   model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
   provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
@@ -459,26 +494,6 @@ models:
   provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
   provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
   model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-Guard-3-8B
-  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
-  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
-  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
-  model_type: llm
-- metadata: {}
-  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
-  model_type: llm
 - metadata:
     embedding_dimension: 768
     context_length: 8192
@@ -523,6 +538,264 @@ models:
   provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
   provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
   model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
+  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_model_id: meta-llama/Llama-Guard-3-8B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
+  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_model_id: meta-llama/Llama-Guard-3-8B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
+  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
+  provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-8b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-8b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-70b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-70b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-405b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
+  provider_model_id: meta.llama3-1-405b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct
+  provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
+  provider_model_id: databricks-meta-llama-3-1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
+  provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
+  provider_model_id: databricks-meta-llama-3-1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct
+  provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
+  provider_model_id: databricks-meta-llama-3-1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
+  provider_model_id: databricks-meta-llama-3-1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama3-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama3-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama3-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama3-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-1b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-1b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-11b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-11b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.3-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: meta/llama-3.3-70b-instruct
+  model_type: llm
+- metadata:
+    embedding_dimension: 2048
+    context_length: 8192
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
+  model_type: embedding
+- metadata:
+    embedding_dimension: 1024
+    context_length: 512
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: nvidia/nv-embedqa-e5-v5
+  model_type: embedding
+- metadata:
+    embedding_dimension: 4096
+    context_length: 512
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
+  model_type: embedding
+- metadata:
+    embedding_dimension: 1024
+    context_length: 512
+  model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l
+  provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
+  provider_model_id: snowflake/arctic-embed-l
+  model_type: embedding
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-8B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-70B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B:bf16-mp8
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B:bf16-mp16
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-8B-Instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-70B-Instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B-Instruct:bf16-mp8
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B-Instruct
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.1-405B-Instruct:bf16-mp16
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.2-1B
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B
+  provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
+  provider_model_id: Llama3.2-3B
+  model_type: llm
 - metadata: {}
   model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
   provider_id: ${env.ENABLE_OPENAI:=__disabled__}
- metadata: {} - metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__} provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@@ -894,7 +1167,25 @@ models:
   model_id: all-MiniLM-L6-v2
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
-shields: []
+shields:
+- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
+- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
+- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
+- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
+- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
+- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
+  provider_id: llama-guard
+  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
 vector_dbs: []
 datasets: []
 scoring_fns: []

View file

@@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
+    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -31,24 +32,75 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.anthropic.models import (
+    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
+)
+from llama_stack.providers.remote.inference.bedrock.models import (
+    MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.bedrock.models import (
+    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
+)
+from llama_stack.providers.remote.inference.cerebras.models import (
+    MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.cerebras.models import (
+    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
+)
+from llama_stack.providers.remote.inference.databricks.databricks import (
+    MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.databricks.databricks import (
+    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.fireworks.models import (
+    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.gemini.models import (
+    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.groq.models import (
+    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
+)
+from llama_stack.providers.remote.inference.nvidia.models import (
+    MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.nvidia.models import (
+    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.openai.models import (
+    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
+)
+from llama_stack.providers.remote.inference.runpod.runpod import (
+    MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.runpod.runpod import (
+    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.sambanova.models import (
+    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.together.models import (
+    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
+)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -72,6 +124,11 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
         "gemini": GEMINI_MODEL_ENTRIES,
         "groq": GROQ_MODEL_ENTRIES,
         "sambanova": SAMBANOVA_MODEL_ENTRIES,
+        "cerebras": CEREBRAS_MODEL_ENTRIES,
+        "bedrock": BEDROCK_MODEL_ENTRIES,
+        "databricks": DATABRICKS_MODEL_ENTRIES,
+        "nvidia": NVIDIA_MODEL_ENTRIES,
+        "runpod": RUNPOD_MODEL_ENTRIES,
     }

     # Special handling for providers with dynamic model entries
@@ -81,6 +138,10 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
                 provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
+            ProviderModelEntry(
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
+                model_type=ModelType.llm,
+            ),
             ProviderModelEntry(
                 provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
                 model_type=ModelType.embedding,
@@ -100,6 +161,35 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
     return model_entries_map.get(provider_type, [])


+def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
+    """Get model entries for a specific provider type."""
+    safety_model_entries_map = {
+        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
+        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
+        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
+        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
+        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
+        "groq": GROQ_SAFETY_MODELS_ENTRIES,
+        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
+        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
+        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
+        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
+        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
+        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
+    }
+
+    # Special handling for providers with dynamic model entries
+    if provider_type == "ollama":
+        return [
+            ProviderModelEntry(
+                provider_model_id="llama-guard3:1b",
+                model_type=ModelType.llm,
+            ),
+        ]
+
+    return safety_model_entries_map.get(provider_type, [])
+
+
 def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
     """Get configuration for a provider using its adapter's config class."""
     config_class = instantiate_class_type(provider_spec.config_class)
@@ -155,6 +245,31 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
     return inference_providers, available_models


+# build a list of shields for all possible providers
+def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
+    shields = []
+    for provider in providers:
+        provider_type = provider.provider_type.split("::")[1]
+        safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
+        if len(safety_model_entries) == 0:
+            continue
+        if provider.provider_id:
+            shield_id = provider.provider_id
+        else:
+            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
+        for safety_model_entry in safety_model_entries:
+            print(f"provider.provider_id: {provider.provider_id}")
+            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
+            shields.append(
+                ShieldInput(
+                    provider_id="llama-guard",
+                    shield_id=shield_id,
+                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
+                )
+            )
+    return shields
+
+
 def get_distribution_template() -> DistributionTemplate:
     remote_inference_providers, available_models = get_remote_inference_providers()
@@ -192,6 +307,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

+    shields = get_shields_for_providers(remote_inference_providers)
+
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -266,9 +383,7 @@ def get_distribution_template() -> DistributionTemplate:
             default_models=default_models + [embedding_model],
             default_tool_groups=default_tool_groups,
             # TODO: add a way to enable/disable shields on the fly
-            # default_shields=[
-            #     ShieldInput(provider_id="llama-guard", shield_id="${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}")
-            # ],
+            default_shields=shields,
         ),
     },
     run_config_env_vars={

View file

@@ -7,7 +7,8 @@ FROM --platform=linux/amd64 ollama/ollama:latest
 RUN ollama serve & \
     sleep 5 && \
     ollama pull llama3.2:3b-instruct-fp16 && \
-    ollama pull all-minilm:l6-v2
+    ollama pull all-minilm:l6-v2 && \
+    ollama pull llama-guard3:1b

 # Set the entrypoint to start ollama serve
 ENTRYPOINT ["ollama", "serve"]