feat(starter)!: simplify starter distro; litellm model registry changes (#2916)

This commit is contained in:
Ashwin Bharambe 2025-07-25 15:02:04 -07:00 committed by GitHub
parent 3344d8a9e5
commit 9583f468f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
64 changed files with 2027 additions and 4092 deletions

View file

@ -7,21 +7,15 @@ distribution_spec:
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::hf::serverless
- provider_type: remote::hf::endpoint
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::databricks
- provider_type: remote::nvidia
- provider_type: remote::runpod
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::groq
- provider_type: remote::llama-openai-compat
- provider_type: remote::sambanova
- provider_type: remote::passthrough
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss

File diff suppressed because it is too large Load diff

View file

@ -7,20 +7,19 @@
from typing import Any
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import (
BuildProvider,
ModelInput,
Provider,
ProviderSpec,
ShieldInput,
ToolGroupInput,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import (
MilvusVectorIOConfig,
@ -29,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig,
)
from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.gemini.models import (
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.together.models import (
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
get_shield_registry,
)
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
model_entries_map = {
"openai": OPENAI_MODEL_ENTRIES,
"fireworks": FIREWORKS_MODEL_ENTRIES,
"together": TOGETHER_MODEL_ENTRIES,
"anthropic": ANTHROPIC_MODEL_ENTRIES,
"gemini": GEMINI_MODEL_ENTRIES,
"groq": GROQ_MODEL_ENTRIES,
"sambanova": SAMBANOVA_MODEL_ENTRIES,
"cerebras": CEREBRAS_MODEL_ENTRIES,
"bedrock": BEDROCK_MODEL_ENTRIES,
"databricks": DATABRICKS_MODEL_ENTRIES,
"nvidia": NVIDIA_MODEL_ENTRIES,
"runpod": RUNPOD_MODEL_ENTRIES,
}
# Special handling for providers with dynamic model entries
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
},
),
]
elif provider_type == "vllm":
return [
ProviderModelEntry(
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
]
return model_entries_map.get(provider_type, [])
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
"ollama": [
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
],
}
return safety_model_entries_map.get(provider_type, [])
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class)
@ -150,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
return {}
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
all_providers = available_providers()
ENABLED_INFERENCE_PROVIDERS = [
"ollama",
"vllm",
"tgi",
"fireworks",
"together",
"gemini",
"groq",
"sambanova",
"anthropic",
"openai",
"cerebras",
"nvidia",
"bedrock",
]
# Filter out inline providers and watsonx - the starter distro only exposes remote providers
INFERENCE_PROVIDER_IDS = {
"vllm": "${env.VLLM_URL:+vllm}",
"tgi": "${env.TGI_URL:+tgi}",
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
}
def get_remote_inference_providers() -> list[Provider]:
# Filter out inline providers and some others - the starter distro only exposes remote providers
remote_providers = [
provider
for provider in all_providers
# TODO: re-add once the Python 3.13 issue is fixed
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx"
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
]
providers = []
available_models = {}
inference_providers = []
for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type
# Build the environment variable name for enabling this provider
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
model_entries = _get_model_entries_for_provider(provider_type)
if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
else:
provider_id = provider_type.replace("-", "_").replace("::", "_")
config = _get_config_for_provider(provider_spec)
providers.append(
(
f"${{env.{env_var}:=__disabled__}}",
provider_type,
model_entries,
config,
)
)
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
inference_providers = []
for provider_id, provider_type, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
@ -191,31 +98,13 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
config=config,
)
)
available_models[provider_id] = model_entries
return inference_providers, available_models
# build a list of shields for all possible providers
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
available_models = {}
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
provider_id = f"${{env.{env_var}:=__disabled__}}"
available_models[provider_id] = safety_model_entries
return available_models
return inference_providers
def get_distribution_template() -> DistributionTemplate:
remote_inference_providers, available_models = get_remote_inference_providers()
remote_inference_providers = get_remote_inference_providers()
name = "starter"
# For build config, use BuildProvider with only provider_type and module
providers = {
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
+ [BuildProvider(provider_type="inline::sentence-transformers")],
@ -254,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
post_training_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
@ -273,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="rag-runtime",
),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id=embedding_provider.provider_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
default_models, ids_conflict_in_models = get_model_registry(available_models)
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
default_shields = [
# if the
ShieldInput(
shield_id="llama-guard",
provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}",
),
]
return DistributionTemplate(
name=name,
@ -294,7 +173,6 @@ def get_distribution_template() -> DistributionTemplate:
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
run_configs={
"run.yaml": RunConfigSettings(
@ -302,22 +180,22 @@ def get_distribution_template() -> DistributionTemplate:
"inference": remote_inference_providers + [embedding_provider],
"vector_io": [
Provider(
provider_id="${env.ENABLE_FAISS:=faiss}",
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
provider_id="${env.MILVUS_URL:+milvus}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
provider_id="${env.CHROMADB_URL:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
@ -325,7 +203,7 @@ def get_distribution_template() -> DistributionTemplate:
),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_id="${env.PGVECTOR_DB:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
@ -336,12 +214,10 @@ def get_distribution_template() -> DistributionTemplate:
),
],
"files": [files_provider],
"post_training": [post_training_provider],
},
default_models=[embedding_model] + default_models,
default_models=[],
default_tool_groups=default_tool_groups,
# TODO: add a way to enable/disable shields on the fly
default_shields=shields,
default_shields=default_shields,
),
},
run_config_env_vars={
@ -385,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
"http://localhost:11434",
"Ollama URL",
),
"OLLAMA_INFERENCE_MODEL": (
"",
"Optional Ollama Inference Model to register on startup",
),
"OLLAMA_EMBEDDING_MODEL": (
"",
"Optional Ollama Embedding Model to register on startup",
),
"OLLAMA_EMBEDDING_DIMENSION": (
"384",
"Ollama Embedding Dimension",
),
},
)