Merge branch 'main' into vectordb_name

Francisco Arceo 2025-07-09 20:53:46 -04:00 committed by GitHub
commit 36ca9543a5
35 changed files with 2282 additions and 1644 deletions

.github/CODEOWNERS

@@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
- * @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist
+ * @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf


@@ -7,3 +7,5 @@ runs:
shell: bash
run: |
docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
# TODO: rebuild an ollama image with llama-guard3:1b
docker exec ollama ollama pull llama-guard3:1b


@@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
- test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
+ test-type: [agents, inference, datasets, inspect, safety, scoring, post_training, providers, tool_runtime, vector_io]
client-type: [library, server]
python-version: ["3.12", "3.13"]
fail-fast: false # we want to run all tests regardless of failure
@@ -51,11 +51,23 @@ jobs:
free -h
df -h
- name: Verify Ollama status is OK
if: matrix.client-type == 'http'
run: |
echo "Verifying Ollama status..."
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
echo "Ollama status: $ollama_status"
if [ "$ollama_status" != "OK" ]; then
echo "Ollama health check failed"
exit 1
fi
- name: Run Integration Tests
env:
- OLLAMA_INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" # for server tests
+ OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
ENABLE_OLLAMA: "ollama" # for server tests
OLLAMA_URL: "http://0.0.0.0:11434"
SAFETY_MODEL: "llama-guard3:1b"
# Use 'shell' to get pipefail behavior
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
# TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash'
@@ -68,8 +80,9 @@ jobs:
fi
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
- --text-model="ollama/meta-llama/Llama-3.2-3B-Instruct" \
+ --text-model="ollama/llama3.2:3b-instruct-fp16" \
--embedding-model=all-MiniLM-L6-v2 \
--safety-shield=ollama \
--color=yes \
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
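The new "Verify Ollama status" step above drives the check with curl and jq against the stack server on port 8321. A rough Python equivalent of that same probe, handy for local debugging (the requests dependency and the error handling are assumptions, not part of this change):

# Sketch of the workflow's Ollama health check, in Python.
# URL, port, and the .health.status path are taken from the step above.
import requests

resp = requests.get("http://127.0.0.1:8321/v1/providers/ollama", timeout=10)
resp.raise_for_status()
status = resp.json()["health"]["status"]
print(f"Ollama status: {status}")
if status != "OK":
    raise SystemExit("Ollama health check failed")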


@@ -11132,8 +11132,38 @@
"title": "Trace"
},
"Checkpoint": {
- "description": "Checkpoint created during training runs",
- "title": "Checkpoint"
+ "type": "object",
+ "properties": {
"identifier": {
"type": "string"
},
"created_at": {
"type": "string",
"format": "date-time"
},
"epoch": {
"type": "integer"
},
"post_training_job_id": {
"type": "string"
},
"path": {
"type": "string"
},
"training_metrics": {
"$ref": "#/components/schemas/PostTrainingMetric"
}
},
"additionalProperties": false,
"required": [
"identifier",
"created_at",
"epoch",
"post_training_job_id",
"path"
],
"title": "Checkpoint",
"description": "Checkpoint created during training runs"
},
"PostTrainingJobArtifactsResponse": {
"type": "object",
@@ -11156,6 +11186,31 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
"PostTrainingMetric": {
"type": "object",
"properties": {
"epoch": {
"type": "integer"
},
"train_loss": {
"type": "number"
},
"validation_loss": {
"type": "number"
},
"perplexity": {
"type": "number"
}
},
"additionalProperties": false,
"required": [
"epoch",
"train_loss",
"validation_loss",
"perplexity"
],
"title": "PostTrainingMetric"
},
"PostTrainingJobStatusResponse": { "PostTrainingJobStatusResponse": {
"type": "object", "type": "object",
"properties": { "properties": {


@@ -7838,8 +7838,30 @@ components:
- start_time
title: Trace
Checkpoint:
- description: Checkpoint created during training runs
+ type: object
properties:
identifier:
type: string
created_at:
type: string
format: date-time
epoch:
type: integer
post_training_job_id:
type: string
path:
type: string
training_metrics:
$ref: '#/components/schemas/PostTrainingMetric'
additionalProperties: false
required:
- identifier
- created_at
- epoch
- post_training_job_id
- path
title: Checkpoint
description: Checkpoint created during training runs
PostTrainingJobArtifactsResponse:
type: object
properties:
@@ -7855,6 +7877,24 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
PostTrainingMetric:
type: object
properties:
epoch:
type: integer
train_loss:
type: number
validation_loss:
type: number
perplexity:
type: number
additionalProperties: false
required:
- epoch
- train_loss
- validation_loss
- perplexity
title: PostTrainingMetric
PostTrainingJobStatusResponse:
type: object
properties:


@@ -13,7 +13,7 @@ Latest Release Notes: [link](https://github.com/meta-llama/llama-stack-client-ko
*Tagged releases are stable versions of the project. While we strive to maintain a stable main branch, it's not guaranteed to be free of bugs or issues.*
## Android Demo App
- Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-client-kotlin/tree/examples/android_app)
+ Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app)
The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
@@ -68,7 +68,7 @@ Ensure the Llama Stack server version is the same as the Kotlin SDK Library for
Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
- How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings)
+ How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings)
### Initialize the Client
A client serves as the primary interface for interacting with a specific inference type and its associated parameters. Only after client is initialized then you can configure and start inferences.
@@ -135,7 +135,7 @@ val result = client!!.inference().chatCompletionStreaming(
### Setup Custom Tool Calling
- Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)
+ Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#tool-calling)
## Advanced Users


@@ -114,6 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.


@@ -19,8 +19,10 @@ class PostTrainingMetric(BaseModel):
perplexity: float
- @json_schema_type(schema={"description": "Checkpoint created during training runs"})
+ @json_schema_type
class Checkpoint(BaseModel):
+ """Checkpoint created during training runs"""
identifier: str
created_at: datetime
epoch: int
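Taken together with the OpenAPI changes above, the documented models now look roughly like the sketch below (field names and the required/optional split are read off the schema diff; the default value for training_metrics is an assumption):

# Sketch reconstructed from the Checkpoint / PostTrainingMetric schema changes above.
from datetime import datetime
from pydantic import BaseModel


class PostTrainingMetric(BaseModel):
    epoch: int
    train_loss: float
    validation_loss: float
    perplexity: float


class Checkpoint(BaseModel):
    """Checkpoint created during training runs"""

    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str
    training_metrics: PostTrainingMetric | None = None  # not in the schema's required list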


@@ -98,6 +98,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
@@ -112,6 +113,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
continue
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.


@@ -178,6 +178,7 @@ def usecases() -> list[UseCase | str]:
),
RawMessage(role="user", content="What is the 100th decimal of pi?"),
RawMessage(
+ role="assistant",
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[


@@ -24,8 +24,8 @@ class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
- input_shields: list[str] = None,
- output_shields: list[str] = None,
+ input_shields: list[str] | None = None,
+ output_shields: list[str] | None = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
@@ -37,6 +37,7 @@ class ShieldRunnerMixin:
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
+ params={},
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
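The run_shield call now passes an explicit empty params dict. A minimal sketch of the same Safety API call made directly (the shield_id and message content are illustrative; the RunShieldResponse.violation handling is assumed from the API's usual shape):

# Sketch: calling the Safety API with the now-explicit params argument.
from llama_stack.apis.inference import UserMessage

async def check_user_message(safety_api) -> None:
    response = await safety_api.run_shield(
        shield_id="ollama",  # illustrative; must match a registered shield
        messages=[UserMessage(content="How do I bake a cake?")],
        params={},
    )
    if response.violation is not None:
        print(response.violation.user_message)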


@@ -39,7 +39,7 @@ class MetaReferenceInferenceConfig(BaseModel):
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
- repos = [m.huggingface_repo for m in permitted_models]
+ repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")


@@ -123,7 +123,8 @@ class TorchtunePostTrainingImpl:
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
- ) -> PostTrainingJob: ...
+ ) -> PostTrainingJob:
+ raise NotImplementedError()
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(


@@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
- if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
- raise ValueError(
- f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
- )
+ # Allow any model to be registered as a shield
+ # The model will be validated during runtime when making inference calls
+ pass
async def run_shield(
self,
@@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
- model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
+ # Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
model_id = shield.provider_resource_id
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model_id in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
- model=model,
+ model=model_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
+ safety_categories=safety_categories,
)
return await impl.run(messages)
@@ -183,20 +196,21 @@ class LlamaGuardShield:
model: str,
inference_api: Inference,
excluded_categories: list[str] | None = None,
+ safety_categories: list[str] | None = None,
):
if excluded_categories is None:
excluded_categories = []
+ if safety_categories is None:
+ safety_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
- if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
- raise ValueError(f"Unsupported model: {model}")
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
+ self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
match = re.match(r"^unsafe\n(.*)$", response)
@@ -214,7 +228,7 @@ class LlamaGuardShield:
final_categories = []
- all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
+ all_categories = self.safety_categories
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
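With register_shield relaxed to accept any model, the category selection above falls back to the default Llama Guard 3 categories for models outside LLAMA_GUARD_MODEL_IDS. A client-side sketch of registering such a shield (IDs mirror the starter run.yaml entries later in this diff; the llama_stack_client call shape is an assumption):

# Sketch: registering an Ollama-served Llama Guard model as a shield.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
client.shields.register(
    shield_id="ollama",                           # illustrative shield id
    provider_id="llama-guard",
    provider_shield_id="ollama/llama-guard3:1b",  # model id as registered with the stack
)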


@@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
"anthropic/claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
- MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
- ProviderModelEntry(
- provider_model_id="anthropic/voyage-3",
- model_type=ModelType.embedding,
- metadata={"embedding_dimension": 1024, "context_length": 32000},
- ),
- ProviderModelEntry(
- provider_model_id="anthropic/voyage-3-lite",
- model_type=ModelType.embedding,
- metadata={"embedding_dimension": 512, "context_length": 32000},
- ),
- ProviderModelEntry(
- provider_model_id="anthropic/voyage-code-3",
- model_type=ModelType.embedding,
- metadata={"embedding_dimension": 1024, "context_length": 32000},
- ),
- ]
+ MODEL_ENTRIES = (
+ [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ + [
+ ProviderModelEntry(
+ provider_model_id="anthropic/voyage-3",
+ model_type=ModelType.embedding,
+ metadata={"embedding_dimension": 1024, "context_length": 32000},
+ ),
+ ProviderModelEntry(
+ provider_model_id="anthropic/voyage-3-lite",
+ model_type=ModelType.embedding,
+ metadata={"embedding_dimension": 512, "context_length": 32000},
+ ),
+ ProviderModelEntry(
+ provider_model_id="anthropic/voyage-code-3",
+ model_type=ModelType.embedding,
+ metadata={"embedding_dimension": 1024, "context_length": 32000},
+ ),
+ ]
+ + SAFETY_MODELS_ENTRIES
+ )


@@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0",
@@ -22,4 +26,4 @@ MODEL_ENTRIES = [
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
@@ -18,4 +21,8 @@ MODEL_ENTRIES = [
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
- ]
+ build_hf_repo_model_entry(
+ "llama-4-scout-17b-16e-instruct",
+ CoreModelId.llama4_scout_17b_16e_instruct.value,
+ ),
+ ] + SAFETY_MODELS_ENTRIES


@@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
- model_entries = [
+ SAFETY_MODELS_ENTRIES = []
+ # https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
+ MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
@@ -56,7 +59,7 @@ model_entries = [
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
- ]
+ ] + SAFETY_MODELS_ENTRIES
class DatabricksInferenceAdapter(
@@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
- ModelRegistryHelper.__init__(self, model_entries=model_entries)
+ ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:


@@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
@@ -40,14 +51,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value,
@@ -64,4 +67,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
},
),
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
"gemini/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
- MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
- ProviderModelEntry(
- provider_model_id="gemini/text-embedding-004",
- model_type=ModelType.embedding,
- metadata={"embedding_dimension": 768, "context_length": 2048},
- ),
- ]
+ MODEL_ENTRIES = (
+ [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ + [
+ ProviderModelEntry(
+ provider_model_id="gemini/text-embedding-004",
+ model_type=ModelType.embedding,
+ metadata={"embedding_dimension": 768, "context_length": 2048},
+ ),
+ ]
+ + SAFETY_MODELS_ENTRIES
+ )


@@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"groq/llama3-8b-8192",
@@ -51,4 +53,4 @@ MODEL_ENTRIES = [
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
@@ -99,4 +102,4 @@ MODEL_ENTRIES = [
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
SAFETY_MODELS_ENTRIES = []
- MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
- ProviderModelEntry(
- provider_model_id=model_id,
- model_type=ModelType.embedding,
- metadata={
- "embedding_dimension": model_info.embedding_dimension,
- "context_length": model_info.context_length,
- },
- )
- for model_id, model_info in EMBEDDING_MODEL_IDS.items()
- ]
+ MODEL_ENTRIES = (
+ [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ + [
+ ProviderModelEntry(
+ provider_model_id=model_id,
+ model_type=ModelType.embedding,
+ metadata={
+ "embedding_dimension": model_info.embedding_dimension,
+ "context_length": model_info.context_length,
+ },
+ )
+ for model_id, model_info in EMBEDDING_MODEL_IDS.items()
+ ]
+ + SAFETY_MODELS_ENTRIES
+ )


@@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
- from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
@@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
class RunpodInferenceAdapter(
ModelRegistryHelper,
@@ -61,25 +71,25 @@ class RunpodInferenceAdapter(
self,
model: str,
content: InterleavedContent,
- sampling_params: Optional[SamplingParams] = None,
- response_format: Optional[ResponseFormat] = None,
- stream: Optional[bool] = False,
- logprobs: Optional[LogProbConfig] = None,
+ sampling_params: SamplingParams | None = None,
+ response_format: ResponseFormat | None = None,
+ stream: bool | None = False,
+ logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
- messages: List[Message],
- sampling_params: Optional[SamplingParams] = None,
- response_format: Optional[ResponseFormat] = None,
- tools: Optional[List[ToolDefinition]] = None,
- tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = None,
- stream: Optional[bool] = False,
- logprobs: Optional[LogProbConfig] = None,
- tool_config: Optional[ToolConfig] = None,
+ messages: list[Message],
+ sampling_params: SamplingParams | None = None,
+ response_format: ResponseFormat | None = None,
+ tools: list[ToolDefinition] | None = None,
+ tool_choice: ToolChoice | None = ToolChoice.auto,
+ tool_prompt_format: ToolPromptFormat | None = None,
+ stream: bool | None = False,
+ logprobs: LogProbConfig | None = None,
+ tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@@ -129,10 +139,10 @@ class RunpodInferenceAdapter(
async def embeddings(
self,
model: str,
- contents: List[str] | List[InterleavedContentItem],
- text_truncation: Optional[TextTruncation] = TextTruncation.none,
- output_dimension: Optional[int] = None,
- task_type: Optional[EmbeddingTaskType] = None,
+ contents: list[str] | list[InterleavedContentItem],
+ text_truncation: TextTruncation | None = TextTruncation.none,
+ output_dimension: int | None = None,
+ task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()


@@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct",
@@ -46,8 +54,4 @@ MODEL_ENTRIES = [
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
- build_hf_repo_model_entry(
- "sambanova/Meta-Llama-Guard-3-8B",
- CoreModelId.llama_guard_3_8b.value,
- ),
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@@ -40,14 +50,6 @@ MODEL_ENTRIES = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
@@ -78,4 +80,4 @@ MODEL_ENTRIES = [
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
),
- ]
+ ] + SAFETY_MODELS_ENTRIES


@@ -14,6 +14,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
- impl = MilvusVectorIOAdapter(config, deps[Api.inference])
+ impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl


@@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
+ from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.schema_utils import json_schema_type
@@ -16,6 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
+ kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
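A minimal construction sketch for the new kvstore field (SqliteKVStoreConfig is one member of the KVStoreConfig union shown in the provider table above; the URI and db_path values are illustrative):

# Sketch only: field names come from MilvusVectorIOConfig above; values are illustrative.
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = MilvusVectorIOConfig(  # the class defined in the snippet above
    uri="http://localhost:19530",
    token=None,
    kvstore=SqliteKVStoreConfig(
        db_path="~/.llama/distributions/starter/milvus_registry.db",
    ),
)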


@@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
]
if additional_aliases:
aliases.extend(additional_aliases)
+ aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=aliases,
@@ -90,7 +91,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
if provider_resource_id:
- if provider_resource_id != supported_model_id:  # be idemopotent, only reject differences
+ if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)


@@ -256,11 +256,46 @@ inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
models:
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.SAFETY_MODEL:=__disabled__}
model_type: llm
- metadata:
embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}
@@ -342,26 +377,6 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
@@ -389,6 +404,26 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
@@ -459,26 +494,6 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
@@ -523,6 +538,264 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-1B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-3B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@@ -894,7 +1167,25 @@ models:
model_id: all-MiniLM-L6-v2
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
model_type: embedding
- shields: []
+ shields:
- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
vector_dbs: []
datasets: []
scoring_fns: []
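Note on the placeholders used throughout the shields block above: each ${env.NAME:=default} token is resolved from the environment when the run config is loaded, with the value after := acting as the fallback. The snippet below is only a minimal sketch of that substitution for readers unfamiliar with the syntax; the helper name and regex are assumptions, not the project's actual implementation.

import os
import re

# Matches tokens of the form ${env.NAME:=default} as they appear in run.yaml.
_PLACEHOLDER = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):=(?P<default>[^}]*)\}")

def resolve_placeholders(value: str) -> str:
    # Substitute each placeholder with the environment value, else its default.
    return _PLACEHOLDER.sub(
        lambda m: os.environ.get(m.group("name"), m.group("default")), value
    )

# Example: with ENABLE_OLLAMA=ollama and SAFETY_MODEL unset, the ollama shield's
# provider_shield_id resolves to "ollama/llama-guard3:1b".
print(resolve_placeholders(
    "${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}"
))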
View file
@@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ProviderSpec,
ShieldInput,
ToolGroupInput,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -31,24 +32,75 @@ from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.anthropic.models import (
SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import (
SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.gemini.models import (
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.gemini.models import (
SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.groq.models import (
SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import (
SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import (
SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.together.models import (
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.together.models import (
SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
@@ -72,6 +124,11 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
"gemini": GEMINI_MODEL_ENTRIES,
"groq": GROQ_MODEL_ENTRIES,
"sambanova": SAMBANOVA_MODEL_ENTRIES,
"cerebras": CEREBRAS_MODEL_ENTRIES,
"bedrock": BEDROCK_MODEL_ENTRIES,
"databricks": DATABRICKS_MODEL_ENTRIES,
"nvidia": NVIDIA_MODEL_ENTRIES,
"runpod": RUNPOD_MODEL_ENTRIES,
}
# Special handling for providers with dynamic model entries
@@ -81,6 +138,10 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
model_type=ModelType.embedding,
@@ -100,6 +161,35 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
return model_entries_map.get(provider_type, [])
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
"openai": OPENAI_SAFETY_MODELS_ENTRIES,
"fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
"together": TOGETHER_SAFETY_MODELS_ENTRIES,
"anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
"gemini": GEMINI_SAFETY_MODELS_ENTRIES,
"groq": GROQ_SAFETY_MODELS_ENTRIES,
"sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
"cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
"bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
"databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
"nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
"runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
}
# Special handling for providers with dynamic model entries
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="llama-guard3:1b",
model_type=ModelType.llm,
),
]
return safety_model_entries_map.get(provider_type, [])
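For orientation, a quick usage sketch of the new helper defined above (assuming the template module is importable; the fields printed are taken from the ProviderModelEntry definitions in this diff):

# With the hard-coded ollama branch above, this yields one entry for llama-guard3:1b.
for entry in _get_model_safety_entries_for_provider("ollama"):
    print(entry.provider_model_id, entry.model_type)

# Providers without registered safety models fall back to an empty list.
print(_get_model_safety_entries_for_provider("some-other-provider"))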
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class)
@@ -155,6 +245,31 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
return inference_providers, available_models
# build a list of shields for all possible providers
def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
shields = []
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
if provider.provider_id:
shield_id = provider.provider_id
else:
raise ValueError(f"Provider {provider.provider_type} has no provider_id")
for safety_model_entry in safety_model_entries:
print(f"provider.provider_id: {provider.provider_id}")
print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
shields.append(
ShieldInput(
provider_id="llama-guard",
shield_id=shield_id,
provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
)
)
return shields
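As a concrete illustration of what the loop above appends for a single provider, here is roughly the ShieldInput produced for fireworks, mirroring the fireworks entry in the generated run.yaml earlier in this diff (values assumed for illustration, not captured output):

# Hypothetical result for the fireworks provider with safety entry
# "accounts/fireworks/models/llama-guard-3-8b":
example_shield = ShieldInput(
    provider_id="llama-guard",
    shield_id="${env.ENABLE_FIREWORKS:=__disabled__}",
    provider_shield_id="${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}",
)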
def get_distribution_template() -> DistributionTemplate:
remote_inference_providers, available_models = get_remote_inference_providers()
@@ -192,6 +307,8 @@ def get_distribution_template() -> DistributionTemplate:
),
]
shields = get_shields_for_providers(remote_inference_providers)
providers = {
"inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -266,9 +383,7 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
# TODO: add a way to enable/disable shields on the fly
-# default_shields=[
-#     ShieldInput(provider_id="llama-guard", shield_id="${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}")
-# ],
+default_shields=shields,
),
},
run_config_env_vars={
View file
@@ -32,7 +32,7 @@ dependencies = [
"openai>=1.66",
"prompt-toolkit",
"python-dotenv",
-"python-jose",
+"python-jose[cryptography]",
"pydantic>=2",
"rich",
"starlette",
@@ -225,7 +225,6 @@ follow_imports = "silent"
# to exclude the entire directory.
exclude = [
# As we fix more and more of these, we should remove them from the list
-"^llama_stack/apis/common/training_types\\.py$",
"^llama_stack/cli/download\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/distribution/build\\.py$",
@@ -243,25 +242,20 @@ exclude = [
"^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
-"^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
-"^llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
-"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
-"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
-"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",
"^llama_stack/providers/inline/safety/prompt_guard/",
View file
@@ -28,6 +28,8 @@ certifi==2025.1.31
# httpcore
# httpx
# requests
cffi==1.17.1 ; platform_python_implementation != 'PyPy'
# via cryptography
charset-normalizer==3.4.1
# via requests
click==8.1.8
@@ -38,6 +40,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==45.0.5
# via python-jose
deprecated==1.2.18
# via
# opentelemetry-api
@@ -156,6 +160,8 @@ pyasn1==0.4.8
# via
# python-jose
# rsa
pycparser==2.22 ; platform_python_implementation != 'PyPy'
# via cffi
pydantic==2.10.6
# via
# fastapi
View file
@@ -7,7 +7,8 @@ FROM --platform=linux/amd64 ollama/ollama:latest
RUN ollama serve & \
sleep 5 && \
ollama pull llama3.2:3b-instruct-fp16 && \
-ollama pull all-minilm:l6-v2
+ollama pull all-minilm:l6-v2 && \
+ollama pull llama-guard3:1b
# Set the entrypoint to start ollama serve
ENTRYPOINT ["ollama", "serve"]
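A small, optional sanity check (not part of this commit): once a container built from this image is running, the Ollama tags endpoint on port 11434 should list the extra guard model. The endpoint and response shape used here are assumptions about the standard Ollama REST API, not something exercised by the CI workflow.

import json
import urllib.request

# Query the local Ollama instance for the models baked into the image.
with urllib.request.urlopen("http://localhost:11434/api/tags") as resp:
    tags = json.load(resp)

names = [m["name"] for m in tags.get("models", [])]
print(names)
assert "llama-guard3:1b" in names, names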
uv.lock generated
File diff suppressed because it is too large