mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 22:48:51 +00:00
feat(starter)!: simplify starter distro; litellm model registry changes (#2916)
This commit is contained in:
parent
3344d8a9e5
commit
9583f468f8
64 changed files with 2027 additions and 4092 deletions
|
@ -7,21 +7,15 @@ distribution_spec:
|
|||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::hf::serverless
|
||||
- provider_type: remote::hf::endpoint
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::databricks
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::runpod
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::llama-openai-compat
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::passthrough
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -7,20 +7,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ProviderSpec,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
||||
MilvusVectorIOConfig,
|
||||
|
@ -29,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
|||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.registry.inference import available_providers
|
||||
from llama_stack.providers.remote.inference.anthropic.models import (
|
||||
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.models import (
|
||||
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.cerebras.models import (
|
||||
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.databricks.databricks import (
|
||||
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.fireworks.models import (
|
||||
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.gemini.models import (
|
||||
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.models import (
|
||||
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.nvidia.models import (
|
||||
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.openai.models import (
|
||||
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.runpod.runpod import (
|
||||
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.sambanova.models import (
|
||||
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.together.models import (
|
||||
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
get_shield_registry,
|
||||
)
|
||||
|
||||
|
||||
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
model_entries_map = {
|
||||
"openai": OPENAI_MODEL_ENTRIES,
|
||||
"fireworks": FIREWORKS_MODEL_ENTRIES,
|
||||
"together": TOGETHER_MODEL_ENTRIES,
|
||||
"anthropic": ANTHROPIC_MODEL_ENTRIES,
|
||||
"gemini": GEMINI_MODEL_ENTRIES,
|
||||
"groq": GROQ_MODEL_ENTRIES,
|
||||
"sambanova": SAMBANOVA_MODEL_ENTRIES,
|
||||
"cerebras": CEREBRAS_MODEL_ENTRIES,
|
||||
"bedrock": BEDROCK_MODEL_ENTRIES,
|
||||
"databricks": DATABRICKS_MODEL_ENTRIES,
|
||||
"nvidia": NVIDIA_MODEL_ENTRIES,
|
||||
"runpod": RUNPOD_MODEL_ENTRIES,
|
||||
}
|
||||
|
||||
# Special handling for providers with dynamic model entries
|
||||
if provider_type == "ollama":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
||||
},
|
||||
),
|
||||
]
|
||||
elif provider_type == "vllm":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
|
||||
return model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
safety_model_entries_map = {
|
||||
"ollama": [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
return safety_model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||
"""Get configuration for a provider using its adapter's config class."""
|
||||
config_class = instantiate_class_type(provider_spec.config_class)
|
||||
|
@ -150,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
|||
return {}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||
all_providers = available_providers()
|
||||
ENABLED_INFERENCE_PROVIDERS = [
|
||||
"ollama",
|
||||
"vllm",
|
||||
"tgi",
|
||||
"fireworks",
|
||||
"together",
|
||||
"gemini",
|
||||
"groq",
|
||||
"sambanova",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"cerebras",
|
||||
"nvidia",
|
||||
"bedrock",
|
||||
]
|
||||
|
||||
# Filter out inline providers and watsonx - the starter distro only exposes remote providers
|
||||
INFERENCE_PROVIDER_IDS = {
|
||||
"vllm": "${env.VLLM_URL:+vllm}",
|
||||
"tgi": "${env.TGI_URL:+tgi}",
|
||||
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
|
||||
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
|
||||
}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> list[Provider]:
|
||||
# Filter out inline providers and some others - the starter distro only exposes remote providers
|
||||
remote_providers = [
|
||||
provider
|
||||
for provider in all_providers
|
||||
# TODO: re-add once the Python 3.13 issue is fixed
|
||||
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828
|
||||
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx"
|
||||
for provider in available_providers()
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
]
|
||||
|
||||
providers = []
|
||||
available_models = {}
|
||||
|
||||
inference_providers = []
|
||||
for provider_spec in remote_providers:
|
||||
provider_type = provider_spec.adapter.adapter_type
|
||||
|
||||
# Build the environment variable name for enabling this provider
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
model_entries = _get_model_entries_for_provider(provider_type)
|
||||
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||
else:
|
||||
provider_id = provider_type.replace("-", "_").replace("::", "_")
|
||||
config = _get_config_for_provider(provider_spec)
|
||||
providers.append(
|
||||
(
|
||||
f"${{env.{env_var}:=__disabled__}}",
|
||||
provider_type,
|
||||
model_entries,
|
||||
config,
|
||||
)
|
||||
)
|
||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
||||
|
||||
inference_providers = []
|
||||
for provider_id, provider_type, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
|
@ -191,31 +98,13 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
|||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
# build a list of shields for all possible providers
|
||||
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
|
||||
available_models = {}
|
||||
for provider in providers:
|
||||
provider_type = provider.provider_type.split("::")[1]
|
||||
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
|
||||
if len(safety_model_entries) == 0:
|
||||
continue
|
||||
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
provider_id = f"${{env.{env_var}:=__disabled__}}"
|
||||
|
||||
available_models[provider_id] = safety_model_entries
|
||||
|
||||
return available_models
|
||||
return inference_providers
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
remote_inference_providers, available_models = get_remote_inference_providers()
|
||||
remote_inference_providers = get_remote_inference_providers()
|
||||
name = "starter"
|
||||
# For build config, use BuildProvider with only provider_type and module
|
||||
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||
+ [BuildProvider(provider_type="inline::sentence-transformers")],
|
||||
|
@ -254,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
post_training_provider = Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
@ -273,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id=embedding_provider.provider_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
default_models, ids_conflict_in_models = get_model_registry(available_models)
|
||||
|
||||
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
|
||||
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
|
||||
default_shields = [
|
||||
# if the
|
||||
ShieldInput(
|
||||
shield_id="llama-guard",
|
||||
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
@ -294,7 +173,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
|
@ -302,22 +180,22 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
|
@ -325,7 +203,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
|
@ -336,12 +214,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
"post_training": [post_training_provider],
|
||||
},
|
||||
default_models=[embedding_model] + default_models,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
# TODO: add a way to enable/disable shields on the fly
|
||||
default_shields=shields,
|
||||
default_shields=default_shields,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
|
@ -385,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"http://localhost:11434",
|
||||
"Ollama URL",
|
||||
),
|
||||
"OLLAMA_INFERENCE_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Inference Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Embedding Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_DIMENSION": (
|
||||
"384",
|
||||
"Ollama Embedding Dimension",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue