mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
feat: introduce "enabled" field for providers
Closes: #2622 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
5561f1c36d
commit
514b0aa4c5
8 changed files with 353 additions and 327 deletions
|
@ -146,11 +146,10 @@ in the runtime configuration to help route to the correct provider.""",
|
||||||
|
|
||||||
|
|
||||||
class Provider(BaseModel):
|
class Provider(BaseModel):
|
||||||
# provider_id of None means that the provider is not enabled - this happens
|
provider_id: str
|
||||||
# when the provider is enabled via a conditional environment variable
|
|
||||||
provider_id: str | None
|
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: dict[str, Any]
|
config: dict[str, Any]
|
||||||
|
enabled: bool = Field(default=True, description="Whether the provider is enabled")
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig(BaseModel):
|
class LoggingConfig(BaseModel):
|
||||||
|
|
|
@ -199,7 +199,7 @@ def validate_and_prepare_providers(
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
if not provider.provider_id or not provider.enabled:
|
||||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -99,19 +99,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
method = getattr(impls[api], register_method)
|
method = getattr(impls[api], register_method)
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
# Do not register models on disabled providers
|
# Do not register models that are explicitly disabled (enabled=False)
|
||||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
if hasattr(obj, "provider_model_id") and obj.provider_model_id is not None and not obj.enabled:
|
||||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
|
||||||
continue
|
|
||||||
# In complex templates, like our starter template, we may have dynamic model ids
|
|
||||||
# given by environment variables. This allows those environment variables to have
|
|
||||||
# a default value of __disabled__ to skip registration of the model if not set.
|
|
||||||
if (
|
|
||||||
hasattr(obj, "provider_model_id")
|
|
||||||
and obj.provider_model_id is not None
|
|
||||||
and "__disabled__" in obj.provider_model_id
|
|
||||||
):
|
|
||||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# we want to maintain the type information in arguments to method.
|
# we want to maintain the type information in arguments to method.
|
||||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||||
|
@ -155,17 +146,20 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||||
for i, v in enumerate(config):
|
for i, v in enumerate(config):
|
||||||
try:
|
try:
|
||||||
# Special handling for providers: first resolve the provider_id to check if provider
|
# Special handling for providers: first resolve the provider_id to check if provider
|
||||||
# is disabled so that we can skip config env variable expansion and avoid validation errors
|
# is disabled so that we can skip config env variable expansion and avoid validation
|
||||||
|
# errors
|
||||||
if isinstance(v, dict) and "provider_id" in v:
|
if isinstance(v, dict) and "provider_id" in v:
|
||||||
try:
|
try:
|
||||||
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
|
# We have to set a default to True because we use Pydantic
|
||||||
if resolved_provider_id == "__disabled__":
|
# exclude_defaults=True in the serializer, so the loaded config only has the
|
||||||
|
# 'enabled' field when it is set to False explicitly.
|
||||||
|
if not v.get("enabled", True):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
f"Skipping config env variable expansion for disabled provider: {v.get('provider_type', '') + '/' if v.get('provider_type', '') else ''}{v.get('provider_id', '')}"
|
||||||
)
|
)
|
||||||
# Create a copy with resolved provider_id but original config
|
# Create a copy that keeps the original (unexpanded) provider_id and config
|
||||||
disabled_provider = v.copy()
|
disabled_provider = v.copy()
|
||||||
disabled_provider["provider_id"] = resolved_provider_id
|
disabled_provider["provider_id"] = v["provider_id"]
|
||||||
result.append(disabled_provider)
|
result.append(disabled_provider)
|
||||||
continue
|
continue
|
||||||
except EnvVarError:
|
except EnvVarError:
|
||||||
|
|
|
@ -68,5 +68,5 @@ class HuggingFacePostTrainingConfig(BaseModel):
|
||||||
dataloader_pin_memory: bool = True
|
dataloader_pin_memory: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||||
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
|
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
|
||||||
|
|
|
@ -25,6 +25,7 @@ class ProviderModelEntry(BaseModel):
|
||||||
llama_model: str | None = None
|
llama_model: str | None = None
|
||||||
model_type: ModelType = ModelType.llm
|
model_type: ModelType = ModelType.llm
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
enabled: bool = Field(default=True, description="Whether the model is enabled")
|
||||||
|
|
||||||
|
|
||||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -78,22 +78,25 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
|
||||||
if provider_type == "ollama":
|
if provider_type == "ollama":
|
||||||
return [
|
return [
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=}",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=}",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={
|
metadata={
|
||||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
||||||
},
|
},
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
elif provider_type == "vllm":
|
elif provider_type == "vllm":
|
||||||
return [
|
return [
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
provider_model_id="${env.VLLM_INFERENCE_MODEL:=}",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -129,29 +132,29 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
||||||
provider_type = provider_spec.adapter.adapter_type
|
provider_type = provider_spec.adapter.adapter_type
|
||||||
|
|
||||||
# Build the environment variable name for enabling this provider
|
# Build the environment variable name for enabling this provider
|
||||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
# only get the provider type after the ::
|
||||||
model_entries = _get_model_entries_for_provider(provider_type)
|
model_entries = _get_model_entries_for_provider(provider_type)
|
||||||
config = _get_config_for_provider(provider_spec)
|
config = _get_config_for_provider(provider_spec)
|
||||||
providers.append(
|
providers.append(
|
||||||
(
|
(
|
||||||
f"${{env.{env_var}:=__disabled__}}",
|
|
||||||
provider_type,
|
provider_type,
|
||||||
model_entries,
|
model_entries,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
available_models[provider_type] = model_entries
|
||||||
|
|
||||||
inference_providers = []
|
inference_providers = []
|
||||||
for provider_id, provider_type, model_entries, config in providers:
|
for provider_type, model_entries, config in providers:
|
||||||
inference_providers.append(
|
inference_providers.append(
|
||||||
Provider(
|
Provider(
|
||||||
provider_id=provider_id,
|
provider_id=provider_type,
|
||||||
provider_type=f"remote::{provider_type}",
|
provider_type=f"remote::{provider_type}",
|
||||||
config=config,
|
config=config,
|
||||||
|
enabled=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
available_models[provider_id] = model_entries
|
available_models[provider_type] = model_entries
|
||||||
return inference_providers, available_models
|
return inference_providers, available_models
|
||||||
|
|
||||||
|
|
||||||
|
@ -162,33 +165,33 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
|
|
||||||
vector_io_providers = [
|
vector_io_providers = [
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
provider_id="faiss",
|
||||||
provider_type="inline::faiss",
|
provider_type="inline::faiss",
|
||||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
provider_id="sqlite-vec",
|
||||||
provider_type="inline::sqlite-vec",
|
provider_type="inline::sqlite-vec",
|
||||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
provider_id="milvus",
|
||||||
provider_type="inline::milvus",
|
provider_type="inline::milvus",
|
||||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
provider_id="chromadb",
|
||||||
provider_type="remote::chromadb",
|
provider_type="remote::chromadb",
|
||||||
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:=}"),
|
config=ChromaVectorIOConfig.sample_run_config(),
|
||||||
|
enabled=False,
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
provider_id="pgvector",
|
||||||
provider_type="remote::pgvector",
|
provider_type="remote::pgvector",
|
||||||
config=PGVectorVectorIOConfig.sample_run_config(
|
config=PGVectorVectorIOConfig.sample_run_config(),
|
||||||
db="${env.PGVECTOR_DB:=}",
|
enabled=False,
|
||||||
user="${env.PGVECTOR_USER:=}",
|
|
||||||
password="${env.PGVECTOR_PASSWORD:=}",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -216,14 +219,14 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
)
|
)
|
||||||
embedding_provider = Provider(
|
embedding_provider = Provider(
|
||||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
provider_id="sentence-transformers",
|
||||||
provider_type="inline::sentence-transformers",
|
provider_type="inline::sentence-transformers",
|
||||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
post_training_provider = Provider(
|
post_training_provider = Provider(
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
provider_type="inline::huggingface",
|
provider_type="inline::huggingface",
|
||||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=HuggingFacePostTrainingConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
default_tool_groups = [
|
default_tool_groups = [
|
||||||
ToolGroupInput(
|
ToolGroupInput(
|
||||||
|
|
|
@ -105,7 +105,8 @@ class RunConfigSettings(BaseModel):
|
||||||
if api_providers := self.provider_overrides.get(api_str):
|
if api_providers := self.provider_overrides.get(api_str):
|
||||||
# Convert Provider objects to dicts for YAML serialization
|
# Convert Provider objects to dicts for YAML serialization
|
||||||
provider_configs[api_str] = [
|
provider_configs[api_str] = [
|
||||||
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
|
p.model_dump(exclude_defaults=True, exclude_none=True) if isinstance(p, Provider) else p
|
||||||
|
for p in api_providers
|
||||||
]
|
]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -133,7 +134,7 @@ class RunConfigSettings(BaseModel):
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
config=config,
|
config=config,
|
||||||
).model_dump(exclude_none=True)
|
).model_dump(exclude_defaults=True, exclude_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get unique set of APIs from providers
|
# Get unique set of APIs from providers
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue