llama-stack-mirror/llama_stack/distributions/starter/starter.py
Ashwin Bharambe 7519b73fcc
feat(distro): fork off a starter-gpu distribution (#3240)
The starter distribution added post-training, which added torch
dependencies, which in turn pull in all the NVIDIA CUDA libraries. This made our
starter container very big. We have worked hard to keep the starter
container small so that it serves its purpose as a starter. This PR brings it
back to its original size by forking off duplicate "-gpu" providers for
post-training. These forked providers are then used by a new
`starter-gpu` distribution, which can pull in all the dependencies.
2025-08-22 15:47:15 -07:00
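
As a rough, hypothetical sketch (assumptions: the module path of the fork, the `inline::huggingface-gpu` provider type, and that template fields are mutable), the forked `starter-gpu` template could reuse this module's `get_distribution_template()` and swap only the post-training provider:

# Hypothetical starter_gpu.py -- illustrative only, not the actual implementation.
from llama_stack.core.datatypes import BuildProvider
from llama_stack.distributions.starter.starter import (
    get_distribution_template as get_starter_template,
)
from llama_stack.distributions.template import DistributionTemplate


def get_distribution_template() -> DistributionTemplate:
    template = get_starter_template()
    template.name = "starter-gpu"  # assumes the template model allows mutation
    template.description = "Starter distribution with GPU-enabled post-training."
    # Swap the CPU-only post-training provider for the forked "-gpu" variant
    # mentioned in the commit message (provider type assumed).
    template.providers["post_training"] = [
        BuildProvider(provider_type="inline::huggingface-gpu"),
    ]
    return template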


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import (
BuildProvider,
Provider,
ProviderSpec,
ShieldInput,
ToolGroupInput,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig,
)
from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig


def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class)
if hasattr(config_class, "sample_run_config"):
config: dict[str, Any] = config_class.sample_run_config()
return config
return {}
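

# Remote inference providers surfaced by the starter distribution. Inline
# inference is excluded except for sentence-transformers, which is added
# separately for embeddings.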
ENABLED_INFERENCE_PROVIDERS = [
"ollama",
"vllm",
"tgi",
"fireworks",
"together",
"gemini",
"vertexai",
"groq",
"sambanova",
"anthropic",
"openai",
"cerebras",
"nvidia",
"bedrock",
]
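
# Providers whose provider_id is gated on an environment variable: the
# ${env.VAR:+value} form resolves to the id only when VAR is set, so these
# providers stay disabled until configured. Providers not listed here use
# their adapter type as the provider_id.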
INFERENCE_PROVIDER_IDS = {
"ollama": "${env.OLLAMA_URL:+ollama}",
"vllm": "${env.VLLM_URL:+vllm}",
"tgi": "${env.TGI_URL:+tgi}",
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
"vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
}


def get_remote_inference_providers() -> list[Provider]:
# Filter out inline providers and some others - the starter distro only exposes remote providers
remote_providers = [
provider
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
]
inference_providers = []
for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type
if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
else:
provider_id = provider_type.replace("-", "_").replace("::", "_")
config = _get_config_for_provider(provider_spec)
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_type}",
config=config,
)
)
return inference_providers


def get_distribution_template() -> DistributionTemplate:
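    """Build the CPU-only starter template from remote inference providers and local supporting stacks."""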
remote_inference_providers = get_remote_inference_providers()
name = "starter"
providers = {
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
+ [BuildProvider(provider_type="inline::sentence-transformers")],
"vector_io": [
BuildProvider(provider_type="inline::faiss"),
BuildProvider(provider_type="inline::sqlite-vec"),
BuildProvider(provider_type="inline::milvus"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [
BuildProvider(provider_type="inline::llama-guard"),
BuildProvider(provider_type="inline::code-scanner"),
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
BuildProvider(provider_type="inline::localfs"),
],
"scoring": [
BuildProvider(provider_type="inline::basic"),
BuildProvider(provider_type="inline::llm-as-judge"),
BuildProvider(provider_type="inline::braintrust"),
],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"batches": [
BuildProvider(provider_type="inline::reference"),
],
}

    files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)

    embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)

    default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]

    default_shields = [
        # These shields are only enabled when the corresponding environment
        # variable (SAFETY_MODEL or CODE_SCANNER_MODEL) is set; otherwise the
        # ${env.VAR:+...} substitution leaves the provider_id empty and the
        # shield is not registered.
ShieldInput(
shield_id="llama-guard",
provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}",
),
ShieldInput(
shield_id="code-scanner",
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
]

    return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None,
template_path=None,
providers=providers,
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": remote_inference_providers + [embedding_provider],
"vector_io": [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.MILVUS_URL:+milvus}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.CHROMADB_URL:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
),
Provider(
provider_id="${env.PGVECTOR_DB:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
),
],
"files": [files_provider],
},
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
),
},
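        # Environment variables documented for this distribution, given as
        # (default value, description) pairs.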
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (
"",
"Fireworks API Key",
),
"OPENAI_API_KEY": (
"",
"OpenAI API Key",
),
"GROQ_API_KEY": (
"",
"Groq API Key",
),
"ANTHROPIC_API_KEY": (
"",
"Anthropic API Key",
),
"GEMINI_API_KEY": (
"",
"Gemini API Key",
),
"VERTEX_AI_PROJECT": (
"",
"Google Cloud Project ID for Vertex AI",
),
"VERTEX_AI_LOCATION": (
"us-central1",
"Google Cloud Location for Vertex AI",
),
"SAMBANOVA_API_KEY": (
"",
"SambaNova API Key",
),
"VLLM_URL": (
"http://localhost:8000/v1",
"vLLM URL",
),
"VLLM_INFERENCE_MODEL": (
"",
"Optional vLLM Inference Model to register on startup",
),
"OLLAMA_URL": (
"http://localhost:11434",
"Ollama URL",
),
},
)