refactor(env)!: enhanced environment variable substitution (#2490)
# What does this PR do?

This commit significantly improves the environment variable substitution functionality in Llama Stack configuration files:

* The version field in configuration files has been changed from string to integer type for better type consistency across build and run configurations.
* The environment variable substitution system for `${env.FOO:}` was fixed and now properly returns an error.
* The environment variable substitution system for `${env.FOO+}` returns None instead of an empty string, which better matches the type annotations on config fields.
* The system includes automatic type conversion for boolean, integer, and float values.
* The error messages have been enhanced to provide clearer guidance when environment variables are missing, including suggestions for using default values or conditional syntax.
* Comprehensive documentation has been added to the configuration guide explaining all supported syntax patterns, best practices, and runtime override capabilities.
* Multiple provider configurations have been updated to use the new conditional syntax for optional API keys, making the system more flexible for different deployment scenarios.
* The telemetry configuration has been improved to properly handle optional endpoints with appropriate validation, ensuring that required endpoints are specified when their corresponding sinks are enabled.
* There were many instances of `${env.NVIDIA_API_KEY:}` that should have caused the code to fail. However, due to a bug, the distro server was still being started, and early validation wasn’t triggered. As a result, failures were likely being handled downstream by the providers. I’ve maintained similar behavior by using `${env.NVIDIA_API_KEY:+}`, though I believe this is incorrect for many configurations. I’ll leave it to each provider to correct it as needed.
* Environment variable substitution now uses the same syntax as Bash parameter expansion.

Signed-off-by: Sébastien Han <seb@redhat.com>
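For illustration, here is a minimal sketch of the expansion rules described above, assuming a single placeholder that spans the whole string. The `expand` helper, its regex, and the conversion rules are hypothetical and are not the actual Llama Stack implementation, which operates on full configuration files.

```python
import os
import re

# Hypothetical illustration of the Bash-style rules described above; this is
# NOT the actual Llama Stack resolver. It only resolves a placeholder that
# spans the whole string.
_ENV_PLACEHOLDER = re.compile(
    r"\$\{env\.(?P<name>[A-Za-z_][A-Za-z0-9_]*)"
    r"(?::(?P<op>[=+])(?P<word>[^}]*))?\}"
)


def _convert(value: str):
    """Best-effort conversion of string values to bool, int, or float."""
    if value.lower() in ("true", "false"):
        return value.lower() == "true"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value


def expand(template: str):
    match = _ENV_PLACEHOLDER.fullmatch(template)
    if match is None:
        return template  # not a placeholder, leave untouched
    name, op, word = match.group("name"), match.group("op"), match.group("word")
    env_value = os.environ.get(name)
    if op == "=":  # ${env.FOO:=default} -> default when FOO is unset or empty
        result = env_value if env_value else word
    elif op == "+":  # ${env.FOO:+word} -> word when FOO is set, otherwise None
        result = word if env_value else None
    else:  # ${env.FOO} -> required; fail fast with a clear error when missing
        if env_value is None:
            raise ValueError(f"Environment variable '{name}' is not set")
        result = env_value
    return _convert(result) if isinstance(result, str) else result


# With MAX_TOKENS unset:      expand("${env.MAX_TOKENS:=4096}")      -> 4096 (int)
# With OPENAI_API_KEY unset:  expand("${env.OPENAI_API_KEY:+}")      -> None
# With PGVECTOR_DB unset:     expand("${env.PGVECTOR_DB}")           -> ValueError
```

The trailing comments mirror the behavior the bullets above describe: defaults via `:=`, conditional values via `:+` that resolve to None when the variable is unset, and a hard error for required variables.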
parent 36d70637b9
commit 43c1f39bd6
91 changed files with 1053 additions and 892 deletions
@@ -23,7 +23,7 @@ class LocalfsFilesImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
-            "storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
+            "storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
             "metadata_store": SqliteSqlStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,
                 db_name="files_metadata.db",
@@ -49,11 +49,11 @@ class MetaReferenceInferenceConfig(BaseModel):
     def sample_run_config(
         cls,
         model: str = "Llama3.2-3B-Instruct",
-        checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
-        quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
-        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
-        max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
-        max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
+        checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
+        quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
+        max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
+        max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
         **kwargs,
     ) -> dict[str, Any]:
         return {
@@ -44,10 +44,10 @@ class VLLMConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
-            "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
-            "max_tokens": "${env.MAX_TOKENS:4096}",
-            "max_model_len": "${env.MAX_MODEL_LEN:4096}",
-            "max_num_seqs": "${env.MAX_NUM_SEQS:4}",
-            "enforce_eager": "${env.ENFORCE_EAGER:False}",
-            "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
+            "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
+            "max_tokens": "${env.MAX_TOKENS:=4096}",
+            "max_model_len": "${env.MAX_MODEL_LEN:=4096}",
+            "max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
+            "enforce_eager": "${env.ENFORCE_EAGER:=False}",
+            "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
         }
@@ -17,5 +17,5 @@ class BraintrustScoringConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "openai_api_key": "${env.OPENAI_API_KEY:}",
+            "openai_api_key": "${env.OPENAI_API_KEY:+}",
         }
@@ -20,12 +20,12 @@ class TelemetrySink(StrEnum):


 class TelemetryConfig(BaseModel):
-    otel_trace_endpoint: str = Field(
-        default="http://localhost:4318/v1/traces",
+    otel_trace_endpoint: str | None = Field(
+        default=None,
         description="The OpenTelemetry collector endpoint URL for traces",
     )
-    otel_metric_endpoint: str = Field(
-        default="http://localhost:4318/v1/metrics",
+    otel_metric_endpoint: str | None = Field(
+        default=None,
         description="The OpenTelemetry collector endpoint URL for metrics",
     )
     service_name: str = Field(
@@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
         return {
-            "service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
-            "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
-            "sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
+            "service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
+            "sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
+            "sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
         }
@@ -87,12 +87,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             trace.set_tracer_provider(provider)
             _TRACER_PROVIDER = provider
             if TelemetrySink.OTEL_TRACE in self.config.sinks:
+                if self.config.otel_trace_endpoint is None:
+                    raise ValueError("otel_trace_endpoint is required when OTEL_TRACE is enabled")
                 span_exporter = OTLPSpanExporter(
                     endpoint=self.config.otel_trace_endpoint,
                 )
                 span_processor = BatchSpanProcessor(span_exporter)
                 trace.get_tracer_provider().add_span_processor(span_processor)
             if TelemetrySink.OTEL_METRIC in self.config.sinks:
+                if self.config.otel_metric_endpoint is None:
+                    raise ValueError("otel_metric_endpoint is required when OTEL_METRIC is enabled")
                 metric_reader = PeriodicExportingMetricReader(
                     OTLPMetricExporter(
                         endpoint=self.config.otel_metric_endpoint,
@@ -19,5 +19,5 @@ class QdrantVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
-            "path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
+            "path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
         }
@@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
-            "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
+            "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
         }
@@ -54,8 +54,8 @@ class NvidiaDatasetIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "api_key": "${env.NVIDIA_API_KEY:}",
-            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
-            "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
-            "datasets_url": "${env.NVIDIA_DATASETS_URL:http://nemo.test}",
+            "api_key": "${env.NVIDIA_API_KEY:+}",
+            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
+            "project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
+            "datasets_url": "${env.NVIDIA_DATASETS_URL:=http://nemo.test}",
         }
@@ -25,5 +25,5 @@ class NVIDIAEvalConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
+            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}",
         }
@@ -55,7 +55,7 @@ class NVIDIAConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
-            "api_key": "${env.NVIDIA_API_KEY:}",
-            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
+            "url": "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
+            "api_key": "${env.NVIDIA_API_KEY:+}",
+            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:=True}",
         }
@@ -17,7 +17,7 @@ class OllamaImplConfig(BaseModel):

     @classmethod
     def sample_run_config(
-        cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", raise_on_connect_error: bool = True, **kwargs
+        cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", raise_on_connect_error: bool = True, **kwargs
     ) -> dict[str, Any]:
         return {
             "url": url,
@@ -25,6 +25,6 @@ class RunpodImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
-            "url": "${env.RUNPOD_URL:}",
-            "api_token": "${env.RUNPOD_API_TOKEN:}",
+            "url": "${env.RUNPOD_URL:+}",
+            "api_token": "${env.RUNPOD_API_TOKEN:+}",
         }
@@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.together.xyz/v1",
-            "api_key": "${env.TOGETHER_API_KEY:}",
+            "api_key": "${env.TOGETHER_API_KEY:+}",
         }
@@ -34,9 +34,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
     @classmethod
     def validate_tls_verify(cls, v):
         if isinstance(v, str):
-            # Check if it's a boolean string
-            if v.lower() in ("true", "false"):
-                return v.lower() == "true"
             # Otherwise, treat it as a cert path
             cert_path = Path(v).expanduser().resolve()
             if not cert_path.exists():
@@ -54,7 +51,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
     ):
         return {
             "url": url,
-            "max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
-            "api_token": "${env.VLLM_API_TOKEN:fake}",
-            "tls_verify": "${env.VLLM_TLS_VERIFY:true}",
+            "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
+            "api_token": "${env.VLLM_API_TOKEN:=fake}",
+            "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
         }
@@ -40,7 +40,7 @@ class WatsonXConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
-            "api_key": "${env.WATSONX_API_KEY:}",
-            "project_id": "${env.WATSONX_PROJECT_ID:}",
+            "url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
+            "api_key": "${env.WATSONX_API_KEY:+}",
+            "project_id": "${env.WATSONX_PROJECT_ID:+}",
         }
@@ -55,10 +55,10 @@ class NvidiaPostTrainingConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "api_key": "${env.NVIDIA_API_KEY:}",
-            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
-            "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
-            "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
+            "api_key": "${env.NVIDIA_API_KEY:+}",
+            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
+            "project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
+            "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
         }


@@ -35,6 +35,6 @@ class NVIDIASafetyConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
-            "config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}",
+            "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}",
+            "config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}",
         }
@@ -22,6 +22,6 @@ class BraveSearchToolConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
-            "api_key": "${env.BRAVE_SEARCH_API_KEY:}",
+            "api_key": "${env.BRAVE_SEARCH_API_KEY:+}",
             "max_results": 3,
         }
@@ -22,6 +22,6 @@ class TavilySearchToolConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
-            "api_key": "${env.TAVILY_SEARCH_API_KEY:}",
+            "api_key": "${env.TAVILY_SEARCH_API_KEY:+}",
             "max_results": 3,
         }
@@ -17,5 +17,5 @@ class WolframAlphaToolConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
-            "api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
+            "api_key": "${env.WOLFRAM_ALPHA_API_KEY:+}",
         }
@@ -22,8 +22,8 @@ class PGVectorVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(
         cls,
-        host: str = "${env.PGVECTOR_HOST:localhost}",
-        port: int = "${env.PGVECTOR_PORT:5432}",
+        host: str = "${env.PGVECTOR_HOST:=localhost}",
+        port: int = "${env.PGVECTOR_PORT:=5432}",
         db: str = "${env.PGVECTOR_DB}",
         user: str = "${env.PGVECTOR_USER}",
         password: str = "${env.PGVECTOR_PASSWORD}",
@@ -45,8 +45,8 @@ class RedisKVStoreConfig(CommonConfig):
         return {
             "type": "redis",
             "namespace": None,
-            "host": "${env.REDIS_HOST:localhost}",
-            "port": "${env.REDIS_PORT:6379}",
+            "host": "${env.REDIS_HOST:=localhost}",
+            "port": "${env.REDIS_PORT:=6379}",
         }

@@ -66,7 +66,7 @@ class SqliteKVStoreConfig(CommonConfig):
         return {
             "type": "sqlite",
             "namespace": None,
-            "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
+            "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
         }

@@ -84,12 +84,12 @@ class PostgresKVStoreConfig(CommonConfig):
         return {
             "type": "postgres",
             "namespace": None,
-            "host": "${env.POSTGRES_HOST:localhost}",
-            "port": "${env.POSTGRES_PORT:5432}",
-            "db": "${env.POSTGRES_DB:llamastack}",
-            "user": "${env.POSTGRES_USER:llamastack}",
-            "password": "${env.POSTGRES_PASSWORD:llamastack}",
-            "table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
+            "host": "${env.POSTGRES_HOST:=localhost}",
+            "port": "${env.POSTGRES_PORT:=5432}",
+            "db": "${env.POSTGRES_DB:=llamastack}",
+            "user": "${env.POSTGRES_USER:=llamastack}",
+            "password": "${env.POSTGRES_PASSWORD:=llamastack}",
+            "table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
         }

     @classmethod
@@ -131,12 +131,12 @@ class MongoDBKVStoreConfig(CommonConfig):
         return {
             "type": "mongodb",
             "namespace": None,
-            "host": "${env.MONGODB_HOST:localhost}",
-            "port": "${env.MONGODB_PORT:5432}",
+            "host": "${env.MONGODB_HOST:=localhost}",
+            "port": "${env.MONGODB_PORT:=5432}",
             "db": "${env.MONGODB_DB}",
             "user": "${env.MONGODB_USER}",
             "password": "${env.MONGODB_PASSWORD}",
-            "collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
+            "collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
         }
@@ -50,7 +50,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
     def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
         return cls(
             type="sqlite",
-            db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
+            db_path="${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
         )

     @property
@@ -78,11 +78,11 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
     def sample_run_config(cls, **kwargs):
         return cls(
             type="postgres",
-            host="${env.POSTGRES_HOST:localhost}",
-            port="${env.POSTGRES_PORT:5432}",
-            db="${env.POSTGRES_DB:llamastack}",
-            user="${env.POSTGRES_USER:llamastack}",
-            password="${env.POSTGRES_PASSWORD:llamastack}",
+            host="${env.POSTGRES_HOST:=localhost}",
+            port="${env.POSTGRES_PORT:=5432}",
+            db="${env.POSTGRES_DB:=llamastack}",
+            user="${env.POSTGRES_USER:=llamastack}",
+            password="${env.POSTGRES_PASSWORD:=llamastack}",
         )