mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 01:53:10 +00:00
ci: test safety with starter (#2628)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 39s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 12s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 50s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, safety) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 1m10s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.13, safety) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, inference) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, datasets) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, agents) (push) Failing after 17s
Integration Tests / test-matrix (server, 3.12, inspect) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1m30s
Integration Tests / test-matrix (server, 3.12, safety) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.12, providers) (push) Failing after 15s
Integration Tests / test-matrix (server, 3.12, scoring) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.13, agents) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.12, vector_io) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.13, datasets) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, inference) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.12, post_training) (push) Failing after 25s
Integration Tests / test-matrix (server, 3.13, inspect) (push) Failing after 7s
Integration Tests / test-matrix (server, 3.13, providers) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, scoring) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 7s
Integration Tests / test-matrix (server, 3.13, safety) (push) Failing after 25s
Integration Tests / test-matrix (server, 3.13, post_training) (push) Failing after 27s
Integration Tests / test-matrix (server, 3.13, tool_runtime) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 9s
Test Llama Stack Build / generate-matrix (push) Successful in 14s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 16s
Test Llama Stack Build / build-single-provider (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, tool_runtime) (push) Failing after 1m7s
Update ReadTheDocs / update-readthedocs (push) Failing after 12s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 29s
Test External Providers / test-external-providers (venv) (push) Failing after 17s
Test Llama Stack Build / build (push) Failing after 13s
Unit Tests / unit-tests (3.12) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 35s
Python Package Build Test / build (3.12) (push) Failing after 31s
Python Package Build Test / build (3.13) (push) Failing after 29s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 34s
Pre-commit / pre-commit (push) Successful in 1m24s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 39s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 12s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 50s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, safety) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 1m10s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.13, safety) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, inference) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, datasets) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, agents) (push) Failing after 17s
Integration Tests / test-matrix (server, 3.12, inspect) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1m30s
Integration Tests / test-matrix (server, 3.12, safety) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.12, providers) (push) Failing after 15s
Integration Tests / test-matrix (server, 3.12, scoring) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.13, agents) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.12, vector_io) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.13, datasets) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, inference) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.12, post_training) (push) Failing after 25s
Integration Tests / test-matrix (server, 3.13, inspect) (push) Failing after 7s
Integration Tests / test-matrix (server, 3.13, providers) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, scoring) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 7s
Integration Tests / test-matrix (server, 3.13, safety) (push) Failing after 25s
Integration Tests / test-matrix (server, 3.13, post_training) (push) Failing after 27s
Integration Tests / test-matrix (server, 3.13, tool_runtime) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 9s
Test Llama Stack Build / generate-matrix (push) Successful in 14s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 16s
Test Llama Stack Build / build-single-provider (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, tool_runtime) (push) Failing after 1m7s
Update ReadTheDocs / update-readthedocs (push) Failing after 12s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 29s
Test External Providers / test-external-providers (venv) (push) Failing after 17s
Test Llama Stack Build / build (push) Failing after 13s
Unit Tests / unit-tests (3.12) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 35s
Python Package Build Test / build (3.12) (push) Failing after 31s
Python Package Build Test / build (3.13) (push) Failing after 29s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 34s
Pre-commit / pre-commit (push) Successful in 1m24s
# What does this PR do? We are now testing the safety capability with the starter image. This includes a few changes: * Enable the safety integration test * Relax the shield model requirements from llama-guard to make it work with llama-guard3:8b coming from Ollama * Expose a shield for each inference provider in the starter distro. The shield will only be registered if the provider is enabled. Closes: https://github.com/meta-llama/llama-stack/issues/2528 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
de01eefdef
commit
9b7eecebcf
20 changed files with 621 additions and 126 deletions
|
@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
|
||||
raise ValueError(
|
||||
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
|
||||
)
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
|
||||
# Use the inference API's model resolution instead of hardcoded mappings
|
||||
# This allows the shield to work with any registered model
|
||||
model_id = shield.provider_resource_id
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model_id in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model,
|
||||
model=model_id,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run(messages)
|
||||
|
@ -183,20 +196,21 @@ class LlamaGuardShield:
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: list[str] | None = None,
|
||||
safety_categories: list[str] | None = None,
|
||||
):
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
if safety_categories is None:
|
||||
safety_categories = []
|
||||
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.safety_categories = safety_categories
|
||||
|
||||
def check_unsafe_response(self, response: str) -> str | None:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
@ -214,7 +228,7 @@ class LlamaGuardShield:
|
|||
|
||||
final_categories = []
|
||||
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
|
||||
all_categories = self.safety_categories
|
||||
for cat in all_categories:
|
||||
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||
if cat_code in excluded_categories:
|
||||
|
|
|
@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
|
|||
"anthropic/claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
|
@ -22,4 +26,4 @@ MODEL_ENTRIES = [
|
|||
"meta.llama3-1-405b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://inference-docs.cerebras.ai/models
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
|
@ -18,4 +21,8 @@ MODEL_ENTRIES = [
|
|||
"llama-3.3-70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
]
|
||||
build_hf_repo_model_entry(
|
||||
"llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
model_entries = [
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
|
@ -56,7 +59,7 @@ model_entries = [
|
|||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
|
@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
|
@ -40,14 +51,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
|
@ -64,4 +67,4 @@ MODEL_ENTRIES = [
|
|||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
|
|||
"gemini/gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
|
@ -51,4 +53,4 @@ MODEL_ENTRIES = [
|
|||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama3-8b-instruct",
|
||||
|
@ -99,4 +102,4 @@ MODEL_ENTRIES = [
|
|||
),
|
||||
# TODO(mf): how do we handle Nemotron models?
|
||||
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
|||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
|
@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
|
||||
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
|
||||
RUNPOD_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
|
@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
|
|||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
}
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(provider_model_id, model_descriptor)
|
||||
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
||||
|
||||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
|
|
|
@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
|
@ -46,8 +54,4 @@ MODEL_ENTRIES = [
|
|||
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
|
@ -40,14 +50,6 @@ MODEL_ENTRIES = [
|
|||
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
|
@ -78,4 +80,4 @@ MODEL_ENTRIES = [
|
|||
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
],
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
|
|||
]
|
||||
if additional_aliases:
|
||||
aliases.extend(additional_aliases)
|
||||
aliases = [alias for alias in aliases if alias is not None]
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=aliases,
|
||||
|
@ -90,7 +91,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
if provider_resource_id:
|
||||
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
|
||||
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
|
||||
raise ValueError(
|
||||
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue