feat: introduce "enabled" field for providers

Closes: #2622
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-07-04 17:13:01 +02:00
parent 5561f1c36d
commit 514b0aa4c5
No known key found for this signature in database
8 changed files with 353 additions and 327 deletions

File diff suppressed because it is too large Load diff

View file

@ -78,22 +78,25 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=}",
model_type=ModelType.llm,
enabled=False,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=}",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
},
enabled=False,
),
]
elif provider_type == "vllm":
return [
ProviderModelEntry(
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
provider_model_id="${env.VLLM_INFERENCE_MODEL:=}",
model_type=ModelType.llm,
enabled=False,
),
]
@ -129,29 +132,29 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
provider_type = provider_spec.adapter.adapter_type
# Build the environment variable name for enabling this provider
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
# only get the provider type after the ::
model_entries = _get_model_entries_for_provider(provider_type)
config = _get_config_for_provider(provider_spec)
providers.append(
(
f"${{env.{env_var}:=__disabled__}}",
provider_type,
model_entries,
config,
)
)
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
available_models[provider_type] = model_entries
inference_providers = []
for provider_id, provider_type, model_entries, config in providers:
for provider_type, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_id=provider_type,
provider_type=f"remote::{provider_type}",
config=config,
enabled=False,
)
)
available_models[provider_id] = model_entries
available_models[provider_type] = model_entries
return inference_providers, available_models
@ -162,33 +165,33 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_FAISS:=faiss}",
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
provider_id="milvus",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
provider_id="chromadb",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:=}"),
config=ChromaVectorIOConfig.sample_run_config(),
enabled=False,
),
Provider(
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
config=PGVectorVectorIOConfig.sample_run_config(),
enabled=False,
),
]
@ -216,14 +219,14 @@ def get_distribution_template() -> DistributionTemplate:
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
post_training_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
config=HuggingFacePostTrainingConfig.sample_run_config(),
)
default_tool_groups = [
ToolGroupInput(