mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 10:00:02 +00:00
feat: introduce "enabled" field for providers
Closes: #2622 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
5561f1c36d
commit
514b0aa4c5
8 changed files with 353 additions and 327 deletions
File diff suppressed because it is too large
Load diff
|
|
@@ -78,22 +78,25 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
|
|||
if provider_type == "ollama":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=}",
|
||||
model_type=ModelType.llm,
|
||||
enabled=False,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=}",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
||||
},
|
||||
enabled=False,
|
||||
),
|
||||
]
|
||||
elif provider_type == "vllm":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=}",
|
||||
model_type=ModelType.llm,
|
||||
enabled=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@@ -129,29 +132,29 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
|||
provider_type = provider_spec.adapter.adapter_type
|
||||
|
||||
# Build the environment variable name for enabling this provider
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
# only get the provider type after the ::
|
||||
model_entries = _get_model_entries_for_provider(provider_type)
|
||||
config = _get_config_for_provider(provider_spec)
|
||||
providers.append(
|
||||
(
|
||||
f"${{env.{env_var}:=__disabled__}}",
|
||||
provider_type,
|
||||
model_entries,
|
||||
config,
|
||||
)
|
||||
)
|
||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
||||
available_models[provider_type] = model_entries
|
||||
|
||||
inference_providers = []
|
||||
for provider_id, provider_type, model_entries, config in providers:
|
||||
for provider_type, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_id=provider_type,
|
||||
provider_type=f"remote::{provider_type}",
|
||||
config=config,
|
||||
enabled=False,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
available_models[provider_type] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
|
|
@@ -162,33 +165,33 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
enabled=False,
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
||||
provider_id="milvus",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
enabled=False,
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
||||
provider_id="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:=}"),
|
||||
config=ChromaVectorIOConfig.sample_run_config(),
|
||||
enabled=False,
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
||||
provider_id="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
config=PGVectorVectorIOConfig.sample_run_config(),
|
||||
enabled=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@@ -216,14 +219,14 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
post_training_provider = Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
|
|
|
|||
|
|
@@ -105,7 +105,8 @@ class RunConfigSettings(BaseModel):
|
|||
if api_providers := self.provider_overrides.get(api_str):
|
||||
# Convert Provider objects to dicts for YAML serialization
|
||||
provider_configs[api_str] = [
|
||||
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
|
||||
p.model_dump(exclude_defaults=True, exclude_none=True) if isinstance(p, Provider) else p
|
||||
for p in api_providers
|
||||
]
|
||||
continue
|
||||
|
||||
|
|
@@ -133,7 +134,7 @@ class RunConfigSettings(BaseModel):
|
|||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
).model_dump(exclude_none=True)
|
||||
).model_dump(exclude_defaults=True, exclude_none=True)
|
||||
)
|
||||
|
||||
# Get unique set of APIs from providers
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue