llama-stack-mirror/llama_stack/distributions/template.py
IAN MILLER 98a5047f9d
feat(prompts): attach prompts to storage stores in run configs (#3893)
# What does this PR do?
This PR attaches prompts to the storage stores in run configs, making it possible to specify prompts as a store in different distributions. The need for this functionality was raised in #3514.

> Note: #3514 is split into three separate PRs; this one is the first of the three.

## Test Plan
Manual testing and updated CI unit tests

Prerequisites:

1. `uv run --with llama-stack llama stack list-deps starter | xargs -L1 uv pip install`

2. `llama stack run starter`

```
INFO     2025-10-23 15:36:17,387 llama_stack.cli.stack.run:100 cli: Using run configuration:                            
         /Users/ianmiller/llama-stack/llama_stack/distributions/starter/run.yaml                                        
INFO     2025-10-23 15:36:17,423 llama_stack.cli.stack.run:157 cli: HTTPS enabled with certificates:                    
           Key: None                                                                                                    
           Cert: None                                                                                                   
INFO     2025-10-23 15:36:17,424 llama_stack.cli.stack.run:159 cli: Listening on ['::', '0.0.0.0']:8321                 
INFO     2025-10-23 15:36:17,749 llama_stack.core.server.server:521 core::server: Run configuration:                    
INFO     2025-10-23 15:36:17,756 llama_stack.core.server.server:524 core::server: apis:                                 
         - agents                                                                                                       
         - batches                                                                                                      
         - datasetio                                                                                                    
         - eval                                                                                                         
         - files                                                                                                        
         - inference                                                                                                    
         - post_training                                                                                                
         - safety                                                                                                       
         - scoring                                                                                                      
         - tool_runtime                                                                                                 
         - vector_io                                                                                                    
         image_name: starter                                                                                            
         providers:                                                                                                     
           agents:                                                                                                      
           - config:                                                                                                    
               persistence:                                                                                             
                 agent_state:                                                                                           
                   backend: kv_default                                                                                  
                   namespace: agents                                                                                    
                 responses:                                                                                             
                   backend: sql_default                                                                                 
                   max_write_queue_size: 10000                                                                          
                   num_writers: 4                                                                                       
                   table_name: responses                                                                                
             provider_id: meta-reference                                                                                
             provider_type: inline::meta-reference                                                                      
           batches:                                                                                                     
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: batches                                                                                     
             provider_id: reference                                                                                     
             provider_type: inline::reference                                                                           
           datasetio:                                                                                                   
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: datasetio::huggingface                                                                      
             provider_id: huggingface                                                                                   
             provider_type: remote::huggingface                                                                         
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: datasetio::localfs                                                                          
             provider_id: localfs                                                                                       
             provider_type: inline::localfs                                                                             
           eval:                                                                                                        
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: eval                                                                                        
             provider_id: meta-reference                                                                                
             provider_type: inline::meta-reference                                                                      
           files:                                                                                                       
           - config:                                                                                                    
               metadata_store:                                                                                          
                 backend: sql_default                                                                                   
                 table_name: files_metadata                                                                             
               storage_dir: /Users/ianmiller/.llama/distributions/starter/files                                         
             provider_id: meta-reference-files                                                                          
             provider_type: inline::localfs                                                                             
           inference:                                                                                                   
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.fireworks.ai/inference/v1                                                               
             provider_id: fireworks                                                                                     
             provider_type: remote::fireworks                                                                           
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.together.xyz/v1                                                                         
             provider_id: together                                                                                      
             provider_type: remote::together                                                                            
           - config: {}                                                                                                 
             provider_id: bedrock                                                                                       
             provider_type: remote::bedrock                                                                             
           - config:                                                                                                    
               api_key: '********'                                                                                      
               base_url: https://api.openai.com/v1                                                                      
             provider_id: openai                                                                                        
             provider_type: remote::openai                                                                              
           - config:                                                                                                    
               api_key: '********'                                                                                      
             provider_id: anthropic                                                                                     
             provider_type: remote::anthropic                                                                           
           - config:                                                                                                    
               api_key: '********'                                                                                      
             provider_id: gemini                                                                                        
             provider_type: remote::gemini                                                                              
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.groq.com                                                                                
             provider_id: groq                                                                                          
             provider_type: remote::groq                                                                                
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.sambanova.ai/v1                                                                         
             provider_id: sambanova                                                                                     
             provider_type: remote::sambanova                                                                           
           - config: {}                                                                                                 
             provider_id: sentence-transformers                                                                         
             provider_type: inline::sentence-transformers                                                               
           post_training:                                                                                               
           - config:                                                                                                    
               checkpoint_format: meta                                                                                  
             provider_id: torchtune-cpu                                                                                 
             provider_type: inline::torchtune-cpu                                                                       
           safety:                                                                                                      
           - config:                                                                                                    
               excluded_categories: []                                                                                  
             provider_id: llama-guard                                                                                   
             provider_type: inline::llama-guard                                                                         
           - config: {}                                                                                                 
             provider_id: code-scanner                                                                                  
             provider_type: inline::code-scanner                                                                        
           scoring:                                                                                                     
           - config: {}                                                                                                 
             provider_id: basic                                                                                         
             provider_type: inline::basic                                                                               
           - config: {}                                                                                                 
             provider_id: llm-as-judge                                                                                  
             provider_type: inline::llm-as-judge                                                                        
           - config:                                                                                                    
               openai_api_key: '********'                                                                               
             provider_id: braintrust                                                                                    
             provider_type: inline::braintrust                                                                          
           tool_runtime:                                                                                                
           - config:                                                                                                    
               api_key: '********'                                                                                      
               max_results: 3                                                                                           
             provider_id: brave-search                                                                                  
             provider_type: remote::brave-search                                                                        
           - config:                                                                                                    
               api_key: '********'                                                                                      
               max_results: 3                                                                                           
             provider_id: tavily-search                                                                                 
             provider_type: remote::tavily-search                                                                       
           - config: {}                                                                                                 
             provider_id: rag-runtime                                                                                   
             provider_type: inline::rag-runtime                                                                         
           - config: {}                                                                                                 
             provider_id: model-context-protocol                                                                        
             provider_type: remote::model-context-protocol                                                              
           vector_io:                                                                                                   
           - config:                                                                                                    
               persistence:                                                                                             
                 backend: kv_default                                                                                    
                 namespace: vector_io::faiss                                                                            
             provider_id: faiss                                                                                         
             provider_type: inline::faiss                                                                               
           - config:                                                                                                    
               db_path: /Users/ianmiller/.llama/distributions/starter/sqlite_vec.db                                     
               persistence:                                                                                             
                 backend: kv_default                                                                                    
                 namespace: vector_io::sqlite_vec                                                                       
             provider_id: sqlite-vec                                                                                    
             provider_type: inline::sqlite-vec                                                                          
         registered_resources:                                                                                          
           benchmarks: []                                                                                               
           datasets: []                                                                                                 
           models: []                                                                                                   
           scoring_fns: []                                                                                              
           shields: []                                                                                                  
           tool_groups:                                                                                                 
           - provider_id: tavily-search                                                                                 
             toolgroup_id: builtin::websearch                                                                           
           - provider_id: rag-runtime                                                                                   
             toolgroup_id: builtin::rag                                                                                 
           vector_stores: []                                                                                            
         server:                                                                                                        
           port: 8321                                                                                                   
         storage:                                                                                                       
           backends:                                                                                                    
             kv_default:                                                                                                
               db_path: /Users/ianmiller/.llama/distributions/starter/kvstore.db                                        
               type: kv_sqlite                                                                                          
             sql_default:                                                                                               
               db_path: /Users/ianmiller/.llama/distributions/starter/sql_store.db                                      
               type: sql_sqlite                                                                                         
           stores:                                                                                                      
             conversations:                                                                                             
               backend: sql_default                                                                                     
               table_name: openai_conversations                                                                         
             inference:                                                                                                 
               backend: sql_default                                                                                     
               max_write_queue_size: 10000                                                                              
               num_writers: 4                                                                                           
               table_name: inference_store                                                                              
             metadata:                                                                                                  
               backend: kv_default                                                                                      
               namespace: registry                                                                                      
             prompts:                                                                                                   
               backend: kv_default                                                                                      
               namespace: prompts                                                                                       
         telemetry:                                                                                                     
           enabled: true                                                                                                
         vector_stores:                                                                                                 
           default_embedding_model:                                                                                     
             model_id: nomic-ai/nomic-embed-text-v1.5                                                                   
             provider_id: sentence-transformers                                                                         
           default_provider_id: faiss                                                                                   
         version: 2                                                                                                     
                                                                                                                        
INFO     2025-10-23 15:36:20,032 llama_stack.providers.utils.inference.inference_store:74 inference: Write queue        
         disabled for SQLite to avoid concurrency issues                                                                
WARNING  2025-10-23 15:36:20,422 llama_stack.providers.inline.telemetry.meta_reference.telemetry:84 telemetry:          
         OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry                                                     
INFO     2025-10-23 15:36:22,379 llama_stack.providers.utils.inference.openai_mixin:436 providers::utils:               
         OpenAIInferenceAdapter.list_provider_model_ids() returned 105 models                                           
INFO     2025-10-23 15:36:22,703 uvicorn.error:84 uncategorized: Started server process [17328]                         
INFO     2025-10-23 15:36:22,704 uvicorn.error:48 uncategorized: Waiting for application startup.                       
INFO     2025-10-23 15:36:22,706 llama_stack.core.server.server:179 core::server: Starting up Llama Stack server        
         (version: 0.3.0)                                                                                               
INFO     2025-10-23 15:36:22,707 llama_stack.core.stack:470 core: starting registry refresh task                        
INFO     2025-10-23 15:36:22,708 uvicorn.error:62 uncategorized: Application startup complete.                          
INFO     2025-10-23 15:36:22,708 uvicorn.error:216 uncategorized: Uvicorn running on http://['::', '0.0.0.0']:8321      
         (Press CTRL+C to quit)   
```
As shown above, the `prompts` store is now attached under `storage.stores` in the run configuration.
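
The relevant stanza, trimmed from the full configuration above:

```
storage:
  stores:
    prompts:
      backend: kv_default
      namespace: prompts
```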

Testing:

1. Create prompt:

```
curl -X POST http://localhost:8321/v1/prompts \                 
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.",
    "variables": ["name", "company", "role", "tone"]
  }'
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

2. Get prompt:

```
curl -X GET http://localhost:8321/v1/prompts/pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

3. Query the SQLite KV store to verify the created prompt:

```
sqlite> .mode column
sqlite> .headers on
sqlite> SELECT * FROM kvstore WHERE key LIKE 'prompts:v1:%';
key                                                           value                                                         expiration
------------------------------------------------------------  ------------------------------------------------------------  ----------
prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  {"prompt_id": "pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab            
163f:1                                                        5f6e163f", "prompt": "Hello {{name}}! You are working at {{c            
                                                              ompany}}. Your role is {{role}} at {{company}}. Remember, {{            
                                                              name}}, to be {{tone}}.", "version": 1, "variables": ["name"            
                                                              , "company", "role", "tone"], "is_default": false}                      

prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  1                                                                       
163f:default                                                                                                                          
sqlite> 
```
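
For reference, the two rows follow this key scheme (inferred from the output above): one key per prompt version holding the serialized prompt JSON, plus a `:default` key recording which version is the default.

```
prompts:v1:<prompt_id>:<version>  -> serialized prompt JSON
prompts:v1:<prompt_id>:default    -> default version number (here: 1)
```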
2025-10-27 11:12:12 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path
from typing import Any, Literal

import jinja2
import rich
import yaml
from pydantic import BaseModel, Field

from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
    LLAMA_STACK_RUN_CONFIG_VERSION,
    Api,
    BenchmarkInput,
    BuildConfig,
    BuildProvider,
    DatasetInput,
    DistributionSpec,
    ModelInput,
    Provider,
    SafetyConfig,
    ShieldInput,
    TelemetryConfig,
    ToolGroupInput,
    VectorStoresConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    SqlStoreReference,
    StorageBackendType,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages


def filter_empty_values(obj: Any) -> Any:
    """Recursively filter out specific empty values from a dictionary or list.

    This function removes:
    - Empty strings ('') only when they are the 'module' field
    - Empty dictionaries ({}) only when they are the 'config' field
    - None values (always excluded)
    """
    if obj is None:
        return None

    if isinstance(obj, dict):
        filtered = {}
        for key, value in obj.items():
            # Special handling for specific fields
            if key == "module" and isinstance(value, str) and value == "":
                # Skip empty module strings
                continue
            elif key == "config" and isinstance(value, dict) and not value:
                # Skip empty config dictionaries
                continue
            elif key == "container_image" and not value:
                # Skip empty container_image names
                continue
            else:
                # For all other fields, recursively filter but preserve empty values
                filtered_value = filter_empty_values(value)
                if filtered_value is not None:
                    filtered[key] = filtered_value
        return filtered
    elif isinstance(obj, list):
        filtered = []
        for item in obj:
            filtered_item = filter_empty_values(item)
            if filtered_item is not None:
                filtered.append(filtered_item)
        return filtered
    else:
        # For all other types (including empty strings and dicts that aren't module/config),
        # preserve them as-is
        return obj
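
# Illustrative sketch (not part of the module): the filter drops only the
# specific empty fields named in the docstring and preserves other falsy values.
#
#   filter_empty_values({"module": "", "config": {}, "container_image": None, "port": 8321, "apis": []})
#   # -> {"port": 8321, "apis": []}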


def get_model_registry(
    available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
    models = []

    # check for conflicts in model ids
    all_ids = set()
    ids_conflict = False

    for _, entries in available_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                if model_id in all_ids:
                    ids_conflict = True
                    rich.print(
                        f"[yellow]Model id {model_id} conflicts; all model ids will be prefixed with provider id[/yellow]"
                    )
                    break
            all_ids.update(ids)
            if ids_conflict:
                break
        if ids_conflict:
            break

    for provider_id, entries in available_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
                models.append(
                    ModelInput(
                        model_id=identifier,
                        provider_model_id=entry.provider_model_id,
                        provider_id=provider_id,
                        model_type=entry.model_type,
                        metadata=entry.metadata,
                    )
                )
    return models, ids_conflict
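
# Illustrative sketch (hypothetical entries): when two providers expose the same
# provider_model_id, the conflict is detected and every model id is prefixed.
#
#   models, conflict = get_model_registry({
#       "fireworks": [ProviderModelEntry(provider_model_id="llama-v3p1-8b-instruct")],
#       "together": [ProviderModelEntry(provider_model_id="llama-v3p1-8b-instruct")],
#   })
#   # conflict is True; model ids become "fireworks/llama-v3p1-8b-instruct"
#   # and "together/llama-v3p1-8b-instruct".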


def get_shield_registry(
    available_safety_models: dict[str, list[ProviderModelEntry]],
    ids_conflict_in_models: bool,
) -> list[ShieldInput]:
    shields = []

    # check for conflicts in shield ids
    all_ids = set()
    ids_conflict = False

    for _, entries in available_safety_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                if model_id in all_ids:
                    ids_conflict = True
                    rich.print(
                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
                    )
                    break
            all_ids.update(ids)
            if ids_conflict:
                break
        if ids_conflict:
            break

    for provider_id, entries in available_safety_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
                shields.append(
                    ShieldInput(
                        shield_id=identifier,
                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
                        if ids_conflict_in_models
                        else entry.provider_model_id,
                    )
                )
    return shields
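
# Illustrative sketch (hypothetical entry): ids_conflict_in_models controls
# whether provider_shield_id gets prefixed to match already-prefixed model ids.
#
#   shields = get_shield_registry(
#       {"llama-guard": [ProviderModelEntry(provider_model_id="meta-llama/Llama-Guard-3-8B")]},
#       ids_conflict_in_models=True,
#   )
#   # shield_id          -> "meta-llama/Llama-Guard-3-8B"
#   # provider_shield_id -> "llama-guard/meta-llama/Llama-Guard-3-8B"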


class DefaultModel(BaseModel):
    model_id: str
    doc_string: str


class RunConfigSettings(BaseModel):
    provider_overrides: dict[str, list[Provider]] = Field(default_factory=dict)
    default_models: list[ModelInput] | None = None
    default_shields: list[ShieldInput] | None = None
    default_tool_groups: list[ToolGroupInput] | None = None
    default_datasets: list[DatasetInput] | None = None
    default_benchmarks: list[BenchmarkInput] | None = None
    vector_stores_config: VectorStoresConfig | None = None
    safety_config: SafetyConfig | None = None
    telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
    storage_backends: dict[str, Any] | None = None
    storage_stores: dict[str, Any] | None = None

    def run_config(
        self,
        name: str,
        providers: dict[str, list[BuildProvider]],
        container_image: str | None = None,
    ) -> dict:
        provider_registry = get_provider_registry()
        provider_configs = {}
        for api_str, provider_objs in providers.items():
            if api_providers := self.provider_overrides.get(api_str):
                # Convert Provider objects to dicts for YAML serialization
                provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers]
                continue

            provider_configs[api_str] = []
            for provider in provider_objs:
                api = Api(api_str)
                if provider.provider_type not in provider_registry[api]:
                    raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")

                provider_id = provider.provider_type.split("::")[-1]
                config_class = provider_registry[api][provider.provider_type].config_class
                assert config_class is not None, (
                    f"No config class for provider type: {provider.provider_type} for API: {api_str}"
                )

                config_class = instantiate_class_type(config_class)
                if hasattr(config_class, "sample_run_config"):
                    config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
                else:
                    config = {}
                # BuildProvider does not have a config attribute; skip assignment
                provider_configs[api_str].append(
                    Provider(
                        provider_id=provider_id,
                        provider_type=provider.provider_type,
                        config=config,
                    ).model_dump(exclude_none=True)
                )

        # Get unique set of APIs from providers
        apis = sorted(providers.keys())

        storage_backends = self.storage_backends or {
            "kv_default": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=f"~/.llama/distributions/{name}",
                db_name="kvstore.db",
            ),
            "sql_default": SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=f"~/.llama/distributions/{name}",
                db_name="sql_store.db",
            ),
        }

        storage_stores = self.storage_stores or {
            "metadata": KVStoreReference(
                backend="kv_default",
                namespace="registry",
            ).model_dump(exclude_none=True),
            "inference": InferenceStoreReference(
                backend="sql_default",
                table_name="inference_store",
            ).model_dump(exclude_none=True),
            "conversations": SqlStoreReference(
                backend="sql_default",
                table_name="openai_conversations",
            ).model_dump(exclude_none=True),
            "prompts": KVStoreReference(
                backend="kv_default",
                namespace="prompts",
            ).model_dump(exclude_none=True),
        }
        storage_config = dict(
            backends=storage_backends,
            stores=storage_stores,
        )

        # Return a dict that matches StackRunConfig structure
        config = {
            "version": LLAMA_STACK_RUN_CONFIG_VERSION,
            "image_name": name,
            "container_image": container_image,
            "apis": apis,
            "providers": provider_configs,
            "storage": storage_config,
            "registered_resources": {
                "models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
                "shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
                "vector_dbs": [],
                "datasets": [d.model_dump(exclude_none=True) for d in (self.default_datasets or [])],
                "scoring_fns": [],
                "benchmarks": [b.model_dump(exclude_none=True) for b in (self.default_benchmarks or [])],
                "tool_groups": [t.model_dump(exclude_none=True) for t in (self.default_tool_groups or [])],
            },
            "server": {
                "port": 8321,
            },
            "telemetry": self.telemetry.model_dump(exclude_none=True) if self.telemetry else None,
        }

        if self.vector_stores_config:
            config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)

        if self.safety_config:
            config["safety"] = self.safety_config.model_dump(exclude_none=True)

        return config
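
# With no overrides, the storage section rendered into run.yaml looks like this
# (abridged; the new "prompts" store added by this PR is included by default):
#
#   storage:
#     backends:
#       kv_default:
#         type: kv_sqlite
#         db_path: ~/.llama/distributions/<name>/kvstore.db
#       sql_default:
#         type: sql_sqlite
#         db_path: ~/.llama/distributions/<name>/sql_store.db
#     stores:
#       metadata: {backend: kv_default, namespace: registry}
#       inference: {backend: sql_default, table_name: inference_store}
#       conversations: {backend: sql_default, table_name: openai_conversations}
#       prompts: {backend: kv_default, namespace: prompts}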


class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.
    """

    name: str
    description: str
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]

    # Now uses BuildProvider for build config, not Provider
    providers: dict[str, list[BuildProvider]]
    run_configs: dict[str, RunConfigSettings]
    template_path: Path | None = None

    # Optional configuration
    run_config_env_vars: dict[str, tuple[str, str]] | None = None
    container_image: str | None = None

    available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None

    # we may want to specify additional pip packages without necessarily indicating a
    # specific "default" inference store (which is what is typically used to dictate
    # additional pip packages)
    additional_pip_packages: list[str] | None = None

    def build_config(self) -> BuildConfig:
        additional_pip_packages: list[str] = []
        for run_config in self.run_configs.values():
            run_config_ = run_config.run_config(self.name, self.providers, self.container_image)

            # TODO: This is a hack to get the dependencies for internal APIs into build
            # We should have a better way to do this by formalizing the concept of "internal" APIs
            # and providers, with a way to specify dependencies for them.
            storage_cfg = run_config_.get("storage", {})
            for backend_cfg in storage_cfg.get("backends", {}).values():
                store_type = backend_cfg.get("type")
                if not store_type:
                    continue
                if str(store_type).startswith("kv_"):
                    additional_pip_packages.extend(get_kv_pip_packages(backend_cfg))
                elif str(store_type).startswith("sql_"):
                    additional_pip_packages.extend(get_sql_pip_packages(backend_cfg))

        if self.additional_pip_packages:
            additional_pip_packages.extend(self.additional_pip_packages)

        # Create minimal providers for build config (without runtime configs)
        build_providers = {}
        for api, providers in self.providers.items():
            build_providers[api] = []
            for provider in providers:
                # Create a minimal build provider object with only essential build information
                build_provider = BuildProvider(
                    provider_type=provider.provider_type,
                    module=provider.module,
                )
                build_providers[api].append(build_provider)

        return BuildConfig(
            distribution_spec=DistributionSpec(
                description=self.description,
                container_image=self.container_image,
                providers=build_providers,
            ),
            image_type=LlamaStackImageType.VENV.value,  # default to venv
            additional_pip_packages=sorted(set(additional_pip_packages)),
        )

    def generate_markdown_docs(self) -> str:
        providers_table = "| API | Provider(s) |\n"
        providers_table += "|-----|-------------|\n"
        for api, providers in sorted(self.providers.items()):
            providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
            providers_table += f"| {api} | {providers_str} |\n"

        if self.template_path is not None:
            template = self.template_path.read_text()
            comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
            orphantext = "---\norphan: true\n---\n"

            if template.startswith(orphantext):
                template = template.replace(orphantext, orphantext + comment)
            else:
                template = comment + template

            # Render template with rich-generated table
            env = jinja2.Environment(
                trim_blocks=True,
                lstrip_blocks=True,
                # NOTE: autoescape is required to prevent XSS attacks
                autoescape=True,
            )
            template = env.from_string(template)

            default_models = []
            if self.available_models_by_provider:
                has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
                for provider_id, model_entries in self.available_models_by_provider.items():
                    for model_entry in model_entries:
                        doc_parts = []
                        if model_entry.aliases:
                            doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
                        if has_multiple_providers:
                            doc_parts.append(f"provider: {provider_id}")

                        default_models.append(
                            DefaultModel(
                                model_id=model_entry.provider_model_id,
                                doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                            )
                        )

            return template.render(
                name=self.name,
                description=self.description,
                providers=self.providers,
                providers_table=providers_table,
                run_config_env_vars=self.run_config_env_vars,
                default_models=default_models,
            )
        return ""

    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
        def enum_representer(dumper, data):
            return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)

        # Register YAML representer for enums
        yaml.add_representer(ModelType, enum_representer)
        yaml.add_representer(DatasetPurpose, enum_representer)
        yaml.add_representer(StorageBackendType, enum_representer)
        yaml.SafeDumper.add_representer(ModelType, enum_representer)
        yaml.SafeDumper.add_representer(DatasetPurpose, enum_representer)
        yaml.SafeDumper.add_representer(StorageBackendType, enum_representer)

        for output_dir in [yaml_output_dir, doc_output_dir]:
            output_dir.mkdir(parents=True, exist_ok=True)

        build_config = self.build_config()
        with open(yaml_output_dir / "build.yaml", "w") as f:
            yaml.safe_dump(
                filter_empty_values(build_config.model_dump(exclude_none=True)),
                f,
                sort_keys=False,
            )

        for yaml_pth, settings in self.run_configs.items():
            run_config = settings.run_config(self.name, self.providers, self.container_image)
            with open(yaml_output_dir / yaml_pth, "w") as f:
                yaml.safe_dump(
                    filter_empty_values(run_config),
                    f,
                    sort_keys=False,
                )

        if self.template_path:
            docs = self.generate_markdown_docs()
            with open(doc_output_dir / f"{self.name}.md", "w") as f:
                f.write(docs if docs.endswith("\n") else docs + "\n")
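
# Illustrative usage sketch (hypothetical values): a distribution definition
# rendering its build.yaml, run.yaml, and markdown docs.
#
#   template = DistributionTemplate(
#       name="starter",
#       description="Quick-start distribution",
#       distro_type="self_hosted",
#       providers={"inference": [BuildProvider(provider_type="remote::openai")]},
#       run_configs={"run.yaml": RunConfigSettings()},
#   )
#   template.save_distribution(Path("out/yaml"), Path("out/docs"))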