This commit is contained in:
Sébastien Han 2025-07-09 15:47:41 +02:00 committed by GitHub
commit f2d3a6c8f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 353 additions and 327 deletions

View file

@@ -146,11 +146,10 @@ in the runtime configuration to help route to the correct provider.""",
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_id: str
provider_type: str
config: dict[str, Any]
enabled: bool = Field(default=True, description="Whether the provider is enabled")
class LoggingConfig(BaseModel):

View file

@@ -199,7 +199,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
if not provider.provider_id or not provider.enabled:
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue

View file

@@ -99,19 +99,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# In complex templates, like our starter template, we may have dynamic model ids
# given by environment variables. This allows those environment variables to have
# a default value of __disabled__ to skip registration of the model if not set.
if (
hasattr(obj, "provider_model_id")
and obj.provider_model_id is not None
and "__disabled__" in obj.provider_model_id
):
if hasattr(obj, "provider_model_id") and obj.provider_model_id is not None and not obj.enabled:
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
@@ -155,17 +146,20 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
for i, v in enumerate(config):
try:
# Special handling for providers: first resolve the provider_id to check if provider
# is disabled so that we can skip config env variable expansion and avoid validation errors
# is disabled so that we can skip config env variable expansion and avoid validation
# errors
if isinstance(v, dict) and "provider_id" in v:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__":
# We have to set a default to True because we use Pydantic
# exclude_defaults=True from the serializer so the loaded config only has
# 'enabled' field when it is set to False explicitly.
if not v.get("enabled", True):
logger.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
f"Skipping config env variable expansion for disabled provider: {v.get('provider_type', '') + '/' if v.get('provider_type', '') else ''}{v.get('provider_id', '')}"
)
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
disabled_provider["provider_id"] = v["provider_id"]
result.append(disabled_provider)
continue
except EnvVarError:

View file

@@ -68,5 +68,5 @@ class HuggingFacePostTrainingConfig(BaseModel):
dataloader_pin_memory: bool = True
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}

View file

@@ -25,6 +25,7 @@ class ProviderModelEntry(BaseModel):
llama_model: str | None = None
model_type: ModelType = ModelType.llm
metadata: dict[str, Any] = Field(default_factory=dict)
enabled: bool = Field(default=True, description="Whether the model is enabled")
def get_huggingface_repo(model_descriptor: str) -> str | None:

File diff suppressed because it is too large Load diff

View file

@@ -78,22 +78,25 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=}",
model_type=ModelType.llm,
enabled=False,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=}",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
},
enabled=False,
),
]
elif provider_type == "vllm":
return [
ProviderModelEntry(
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
provider_model_id="${env.VLLM_INFERENCE_MODEL:=}",
model_type=ModelType.llm,
enabled=False,
),
]
@@ -129,29 +132,29 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
provider_type = provider_spec.adapter.adapter_type
# Build the environment variable name for enabling this provider
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
# only get the provider type after the ::
model_entries = _get_model_entries_for_provider(provider_type)
config = _get_config_for_provider(provider_spec)
providers.append(
(
f"${{env.{env_var}:=__disabled__}}",
provider_type,
model_entries,
config,
)
)
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
available_models[provider_type] = model_entries
inference_providers = []
for provider_id, provider_type, model_entries, config in providers:
for provider_type, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_id=provider_type,
provider_type=f"remote::{provider_type}",
config=config,
enabled=False,
)
)
available_models[provider_id] = model_entries
available_models[provider_type] = model_entries
return inference_providers, available_models
@@ -162,33 +165,33 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_FAISS:=faiss}",
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
provider_id="milvus",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
provider_id="chromadb",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:=}"),
config=ChromaVectorIOConfig.sample_run_config(),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
config=PGVectorVectorIOConfig.sample_run_config(),
enabled=False,
),
]
@@ -216,14 +219,14 @@ def get_distribution_template() -> DistributionTemplate:
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
post_training_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
config=HuggingFacePostTrainingConfig.sample_run_config(),
)
default_tool_groups = [
ToolGroupInput(

View file

@@ -105,7 +105,8 @@ class RunConfigSettings(BaseModel):
if api_providers := self.provider_overrides.get(api_str):
# Convert Provider objects to dicts for YAML serialization
provider_configs[api_str] = [
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
p.model_dump(exclude_defaults=True, exclude_none=True) if isinstance(p, Provider) else p
for p in api_providers
]
continue
@@ -133,7 +134,7 @@ class RunConfigSettings(BaseModel):
provider_id=provider_id,
provider_type=provider_type,
config=config,
).model_dump(exclude_none=True)
).model_dump(exclude_defaults=True, exclude_none=True)
)
# Get unique set of APIs from providers