Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
# What does this PR do?

This PR attaches prompts to storage stores in run configs, allowing prompts to be specified as stores in the different distributions. The need for this functionality was raised in #3514.

> Note: #3514 is split into three separate PRs; this one is the first of the three.

## Test Plan

Manual testing and updated CI unit tests.

Prerequisites:

1. `uv run --with llama-stack llama stack list-deps starter | xargs -L1 uv pip install`
2. `llama stack run starter`

```
INFO 2025-10-23 15:36:17,387 llama_stack.cli.stack.run:100 cli: Using run configuration: /Users/ianmiller/llama-stack/llama_stack/distributions/starter/run.yaml
INFO 2025-10-23 15:36:17,423 llama_stack.cli.stack.run:157 cli: HTTPS enabled with certificates: Key: None Cert: None
INFO 2025-10-23 15:36:17,424 llama_stack.cli.stack.run:159 cli: Listening on ['::', '0.0.0.0']:8321
INFO 2025-10-23 15:36:17,749 llama_stack.core.server.server:521 core::server: Run configuration:
INFO 2025-10-23 15:36:17,756 llama_stack.core.server.server:524 core::server: apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- tool_runtime
- vector_io
image_name: starter
providers:
  agents:
  - config:
      persistence:
        agent_state:
          backend: kv_default
          namespace: agents
        responses:
          backend: sql_default
          max_write_queue_size: 10000
          num_writers: 4
          table_name: responses
    provider_id: meta-reference
    provider_type: inline::meta-reference
  batches:
  - config:
      kvstore:
        backend: kv_default
        namespace: batches
    provider_id: reference
    provider_type: inline::reference
  datasetio:
  - config:
      kvstore:
        backend: kv_default
        namespace: datasetio::huggingface
    provider_id: huggingface
    provider_type: remote::huggingface
  - config:
      kvstore:
        backend: kv_default
        namespace: datasetio::localfs
    provider_id: localfs
    provider_type: inline::localfs
  eval:
  - config:
      kvstore:
        backend: kv_default
        namespace: eval
    provider_id: meta-reference
    provider_type: inline::meta-reference
  files:
  - config:
      metadata_store:
        backend: sql_default
        table_name: files_metadata
      storage_dir: /Users/ianmiller/.llama/distributions/starter/files
    provider_id: meta-reference-files
    provider_type: inline::localfs
  inference:
  - config:
      api_key: '********'
      url: https://api.fireworks.ai/inference/v1
    provider_id: fireworks
    provider_type: remote::fireworks
  - config:
      api_key: '********'
      url: https://api.together.xyz/v1
    provider_id: together
    provider_type: remote::together
  - config: {}
    provider_id: bedrock
    provider_type: remote::bedrock
  - config:
      api_key: '********'
      base_url: https://api.openai.com/v1
    provider_id: openai
    provider_type: remote::openai
  - config:
      api_key: '********'
    provider_id: anthropic
    provider_type: remote::anthropic
  - config:
      api_key: '********'
    provider_id: gemini
    provider_type: remote::gemini
  - config:
      api_key: '********'
      url: https://api.groq.com
    provider_id: groq
    provider_type: remote::groq
  - config:
      api_key: '********'
      url: https://api.sambanova.ai/v1
    provider_id: sambanova
    provider_type: remote::sambanova
  - config: {}
    provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
  post_training:
  - config:
      checkpoint_format: meta
    provider_id: torchtune-cpu
    provider_type: inline::torchtune-cpu
  safety:
  - config:
      excluded_categories: []
    provider_id: llama-guard
    provider_type: inline::llama-guard
  - config: {}
    provider_id: code-scanner
    provider_type: inline::code-scanner
  scoring:
  - config: {}
    provider_id: basic
    provider_type: inline::basic
  - config: {}
    provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
  - config:
      openai_api_key: '********'
    provider_id: braintrust
    provider_type: inline::braintrust
  tool_runtime:
  - config:
      api_key: '********'
      max_results: 3
    provider_id: brave-search
    provider_type: remote::brave-search
  - config:
      api_key: '********'
      max_results: 3
    provider_id: tavily-search
    provider_type: remote::tavily-search
  - config: {}
    provider_id: rag-runtime
    provider_type: inline::rag-runtime
  - config: {}
    provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
  vector_io:
  - config:
      persistence:
        backend: kv_default
        namespace: vector_io::faiss
    provider_id: faiss
    provider_type: inline::faiss
  - config:
      db_path: /Users/ianmiller/.llama/distributions/starter/sqlite_vec.db
      persistence:
        backend: kv_default
        namespace: vector_io::sqlite_vec
    provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
registered_resources:
  benchmarks: []
  datasets: []
  models: []
  scoring_fns: []
  shields: []
  tool_groups:
  - provider_id: tavily-search
    toolgroup_id: builtin::websearch
  - provider_id: rag-runtime
    toolgroup_id: builtin::rag
  vector_stores: []
server:
  port: 8321
storage:
  backends:
    kv_default:
      db_path: /Users/ianmiller/.llama/distributions/starter/kvstore.db
      type: kv_sqlite
    sql_default:
      db_path: /Users/ianmiller/.llama/distributions/starter/sql_store.db
      type: sql_sqlite
  stores:
    conversations:
      backend: sql_default
      table_name: openai_conversations
    inference:
      backend: sql_default
      max_write_queue_size: 10000
      num_writers: 4
      table_name: inference_store
    metadata:
      backend: kv_default
      namespace: registry
    prompts:
      backend: kv_default
      namespace: prompts
telemetry:
  enabled: true
vector_stores:
  default_embedding_model:
    model_id: nomic-ai/nomic-embed-text-v1.5
    provider_id: sentence-transformers
  default_provider_id: faiss
version: 2
INFO 2025-10-23 15:36:20,032 llama_stack.providers.utils.inference.inference_store:74 inference: Write queue disabled for SQLite to avoid concurrency issues
WARNING 2025-10-23 15:36:20,422 llama_stack.providers.inline.telemetry.meta_reference.telemetry:84 telemetry: OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry
INFO 2025-10-23 15:36:22,379 llama_stack.providers.utils.inference.openai_mixin:436 providers::utils: OpenAIInferenceAdapter.list_provider_model_ids() returned 105 models
INFO 2025-10-23 15:36:22,703 uvicorn.error:84 uncategorized: Started server process [17328]
INFO 2025-10-23 15:36:22,704 uvicorn.error:48 uncategorized: Waiting for application startup.
INFO 2025-10-23 15:36:22,706 llama_stack.core.server.server:179 core::server: Starting up Llama Stack server (version: 0.3.0)
INFO 2025-10-23 15:36:22,707 llama_stack.core.stack:470 core: starting registry refresh task
INFO 2025-10-23 15:36:22,708 uvicorn.error:62 uncategorized: Application startup complete.
INFO 2025-10-23 15:36:22,708 uvicorn.error:216 uncategorized: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
```

As you can see, the `prompts` store is attached to the storage stores in the config.

Testing:

1. Create a prompt:

```
curl -X POST http://localhost:8321/v1/prompts \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.",
    "variables": ["name", "company", "role", "tone"]
  }'
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

2. Get the prompt:

```
curl -X GET http://localhost:8321/v1/prompts/pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

3. Query the SQLite KV storage to check the created prompt:

```
sqlite> .mode column
sqlite> .headers on
sqlite> SELECT * FROM kvstore WHERE key LIKE 'prompts:v1:%';
key                                                           value                                                         expiration
------------------------------------------------------------  ------------------------------------------------------------  ----------
prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  {"prompt_id": "pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab
163f:1                                                        5f6e163f", "prompt": "Hello {{name}}! You are working at {{c
                                                              ompany}}. Your role is {{role}} at {{company}}. Remember, {{
                                                              name}}, to be {{tone}}.", "version": 1, "variables": ["name"
                                                              , "company", "role", "tone"], "is_default": false}
prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  1
163f:default
sqlite>
```
465 lines
18 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path
from typing import Any, Literal

import jinja2
import rich
import yaml
from pydantic import BaseModel, Field

from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
    LLAMA_STACK_RUN_CONFIG_VERSION,
    Api,
    BenchmarkInput,
    BuildConfig,
    BuildProvider,
    DatasetInput,
    DistributionSpec,
    ModelInput,
    Provider,
    SafetyConfig,
    ShieldInput,
    TelemetryConfig,
    ToolGroupInput,
    VectorStoresConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    SqlStoreReference,
    StorageBackendType,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages

def filter_empty_values(obj: Any) -> Any:
    """Recursively filter out specific empty values from a dictionary or list.

    This function removes:
    - Empty strings ('') only when they are the 'module' field
    - Empty dictionaries ({}) only when they are the 'config' field
    - Falsy values (None, '') only when they are the 'container_image' field
    - None items inside lists (a top-level None input also yields None)
    """
    if obj is None:
        return None

    if isinstance(obj, dict):
        filtered = {}
        for key, value in obj.items():
            # Special handling for specific fields
            if key == "module" and isinstance(value, str) and value == "":
                # Skip empty module strings
                continue
            elif key == "config" and isinstance(value, dict) and not value:
                # Skip empty config dictionaries
                continue
            elif key == "container_image" and not value:
                # Skip empty container_image names
                continue
            else:
                # For all other fields, recursively filter but preserve empty values
                filtered_value = filter_empty_values(value)
                filtered[key] = filtered_value
        return filtered

    elif isinstance(obj, list):
        filtered = []
        for item in obj:
            filtered_item = filter_empty_values(item)
            if filtered_item is not None:
                filtered.append(filtered_item)
        return filtered

    else:
        # For all other types (including empty strings and dicts that aren't module/config),
        # preserve them as-is
        return obj

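# Editorial sketch (not part of the original file): a doctest-style illustration
# of the special-cased keys above.
#
#   >>> filter_empty_values({"module": "", "config": {}, "container_image": None, "name": ""})
#   {'name': ''}
#
# Empty "module"/"config"/"container_image" entries are dropped, while the empty
# string under the unrelated key "name" is preserved.
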
def get_model_registry(
    available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
    models = []

    # check for conflicts in model ids
    all_ids = set()
    ids_conflict = False

    for _, entries in available_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                if model_id in all_ids:
                    ids_conflict = True
                    rich.print(
                        f"[yellow]Model id {model_id} conflicts; all model ids will be prefixed with provider id[/yellow]"
                    )
                    break
            all_ids.update(ids)
            if ids_conflict:
                break
        if ids_conflict:
            break

    for provider_id, entries in available_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
                models.append(
                    ModelInput(
                        model_id=identifier,
                        provider_model_id=entry.provider_model_id,
                        provider_id=provider_id,
                        model_type=entry.model_type,
                        metadata=entry.metadata,
                    )
                )
    return models, ids_conflict

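# Editorial sketch (hypothetical entries, not part of the original file): when
# two providers expose the same provider_model_id, every model id is prefixed.
# Assumes ProviderModelEntry defaults aliases to an empty list.
#
#   >>> entries = {
#   ...     "fireworks": [ProviderModelEntry(provider_model_id="llama-v3")],
#   ...     "together": [ProviderModelEntry(provider_model_id="llama-v3")],
#   ... }
#   >>> models, conflict = get_model_registry(entries)
#   >>> conflict
#   True
#   >>> [m.model_id for m in models]
#   ['fireworks/llama-v3', 'together/llama-v3']
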
def get_shield_registry(
    available_safety_models: dict[str, list[ProviderModelEntry]],
    ids_conflict_in_models: bool,
) -> list[ShieldInput]:
    shields = []

    # check for conflicts in shield ids
    all_ids = set()
    ids_conflict = False

    for _, entries in available_safety_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                if model_id in all_ids:
                    ids_conflict = True
                    rich.print(
                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
                    )
                    break
            all_ids.update(ids)
            if ids_conflict:
                break
        if ids_conflict:
            break

    for provider_id, entries in available_safety_models.items():
        for entry in entries:
            ids = [entry.provider_model_id] + entry.aliases
            for model_id in ids:
                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
                shields.append(
                    ShieldInput(
                        shield_id=identifier,
                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
                        if ids_conflict_in_models
                        else entry.provider_model_id,
                    )
                )

    return shields

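# Editorial sketch (hypothetical entry, not part of the original file):
# ids_conflict_in_models only affects provider_shield_id, so a model-id
# conflict detected earlier still prefixes the provider-side id even when the
# shield ids themselves are unique.
#
#   >>> shields = get_shield_registry(
#   ...     {"llama-guard": [ProviderModelEntry(provider_model_id="guard-8b")]},
#   ...     ids_conflict_in_models=True,
#   ... )
#   >>> (shields[0].shield_id, shields[0].provider_shield_id)
#   ('guard-8b', 'llama-guard/guard-8b')
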
class DefaultModel(BaseModel):
    model_id: str
    doc_string: str

class RunConfigSettings(BaseModel):
    provider_overrides: dict[str, list[Provider]] = Field(default_factory=dict)
    default_models: list[ModelInput] | None = None
    default_shields: list[ShieldInput] | None = None
    default_tool_groups: list[ToolGroupInput] | None = None
    default_datasets: list[DatasetInput] | None = None
    default_benchmarks: list[BenchmarkInput] | None = None
    vector_stores_config: VectorStoresConfig | None = None
    safety_config: SafetyConfig | None = None
    telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
    storage_backends: dict[str, Any] | None = None
    storage_stores: dict[str, Any] | None = None

    def run_config(
        self,
        name: str,
        providers: dict[str, list[BuildProvider]],
        container_image: str | None = None,
    ) -> dict:
        provider_registry = get_provider_registry()
        provider_configs = {}
        for api_str, provider_objs in providers.items():
            if api_providers := self.provider_overrides.get(api_str):
                # Convert Provider objects to dicts for YAML serialization
                provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers]
                continue

            provider_configs[api_str] = []
            for provider in provider_objs:
                api = Api(api_str)
                if provider.provider_type not in provider_registry[api]:
                    raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")

                provider_id = provider.provider_type.split("::")[-1]

                config_class = provider_registry[api][provider.provider_type].config_class
                assert config_class is not None, (
                    f"No config class for provider type: {provider.provider_type} for API: {api_str}"
                )

                config_class = instantiate_class_type(config_class)
                if hasattr(config_class, "sample_run_config"):
                    config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
                else:
                    config = {}
                # BuildProvider does not have a config attribute; skip assignment
                provider_configs[api_str].append(
                    Provider(
                        provider_id=provider_id,
                        provider_type=provider.provider_type,
                        config=config,
                    ).model_dump(exclude_none=True)
                )

        # Get unique set of APIs from providers
        apis = sorted(providers.keys())

        storage_backends = self.storage_backends or {
            "kv_default": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=f"~/.llama/distributions/{name}",
                db_name="kvstore.db",
            ),
            "sql_default": SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=f"~/.llama/distributions/{name}",
                db_name="sql_store.db",
            ),
        }

        storage_stores = self.storage_stores or {
            "metadata": KVStoreReference(
                backend="kv_default",
                namespace="registry",
            ).model_dump(exclude_none=True),
            "inference": InferenceStoreReference(
                backend="sql_default",
                table_name="inference_store",
            ).model_dump(exclude_none=True),
            "conversations": SqlStoreReference(
                backend="sql_default",
                table_name="openai_conversations",
            ).model_dump(exclude_none=True),
            "prompts": KVStoreReference(
                backend="kv_default",
                namespace="prompts",
            ).model_dump(exclude_none=True),
        }

        storage_config = dict(
            backends=storage_backends,
            stores=storage_stores,
        )

        # Return a dict that matches StackRunConfig structure
        config = {
            "version": LLAMA_STACK_RUN_CONFIG_VERSION,
            "image_name": name,
            "container_image": container_image,
            "apis": apis,
            "providers": provider_configs,
            "storage": storage_config,
            "registered_resources": {
                "models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
                "shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
                "vector_dbs": [],
                "datasets": [d.model_dump(exclude_none=True) for d in (self.default_datasets or [])],
                "scoring_fns": [],
                "benchmarks": [b.model_dump(exclude_none=True) for b in (self.default_benchmarks or [])],
                "tool_groups": [t.model_dump(exclude_none=True) for t in (self.default_tool_groups or [])],
            },
            "server": {
                "port": 8321,
            },
            "telemetry": self.telemetry.model_dump(exclude_none=True) if self.telemetry else None,
        }

        if self.vector_stores_config:
            config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)

        if self.safety_config:
            config["safety"] = self.safety_config.model_dump(exclude_none=True)

        return config

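# Editorial note: the "prompts" KVStoreReference above is the change this PR
# introduces. Every generated run config now carries a prompts store by
# default, which serializes to the YAML fragment seen in the server log:
#
#   storage:
#     stores:
#       prompts:
#         backend: kv_default
#         namespace: prompts
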
class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.
    """

    name: str
    description: str
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]

    # Now uses BuildProvider for build config, not Provider
    providers: dict[str, list[BuildProvider]]
    run_configs: dict[str, RunConfigSettings]
    template_path: Path | None = None

    # Optional configuration
    run_config_env_vars: dict[str, tuple[str, str]] | None = None
    container_image: str | None = None

    available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None

    # We may want to specify additional pip packages without necessarily indicating a
    # specific "default" inference store (which is what has typically dictated the
    # additional pip packages).
    additional_pip_packages: list[str] | None = None

    def build_config(self) -> BuildConfig:
        additional_pip_packages: list[str] = []
        for run_config in self.run_configs.values():
            run_config_ = run_config.run_config(self.name, self.providers, self.container_image)

            # TODO: This is a hack to get the dependencies for internal APIs into build
            # We should have a better way to do this by formalizing the concept of "internal" APIs
            # and providers, with a way to specify dependencies for them.

            storage_cfg = run_config_.get("storage", {})
            for backend_cfg in storage_cfg.get("backends", {}).values():
                store_type = backend_cfg.get("type")
                if not store_type:
                    continue
                if str(store_type).startswith("kv_"):
                    additional_pip_packages.extend(get_kv_pip_packages(backend_cfg))
                elif str(store_type).startswith("sql_"):
                    additional_pip_packages.extend(get_sql_pip_packages(backend_cfg))

        if self.additional_pip_packages:
            additional_pip_packages.extend(self.additional_pip_packages)

        # Create minimal providers for build config (without runtime configs)
        build_providers = {}
        for api, providers in self.providers.items():
            build_providers[api] = []
            for provider in providers:
                # Create a minimal build provider object with only essential build information
                build_provider = BuildProvider(
                    provider_type=provider.provider_type,
                    module=provider.module,
                )
                build_providers[api].append(build_provider)

        return BuildConfig(
            distribution_spec=DistributionSpec(
                description=self.description,
                container_image=self.container_image,
                providers=build_providers,
            ),
            image_type=LlamaStackImageType.VENV.value,  # default to venv
            additional_pip_packages=sorted(set(additional_pip_packages)),
        )

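    # Editorial note (hypothetical package names): build_config derives extra
    # pip dependencies from the storage backend types ("kv_*" vs "sql_*") of
    # every generated run config, then dedupes and sorts them:
    #
    #   >>> sorted(set(["aiosqlite", "aiosqlite", "sqlalchemy"]))
    #   ['aiosqlite', 'sqlalchemy']
    #
    # The actual package lists come from get_kv_pip_packages /
    # get_sql_pip_packages, not from this file.
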
    def generate_markdown_docs(self) -> str:
        providers_table = "| API | Provider(s) |\n"
        providers_table += "|-----|-------------|\n"

        for api, providers in sorted(self.providers.items()):
            providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
            providers_table += f"| {api} | {providers_str} |\n"

        if self.template_path is not None:
            template = self.template_path.read_text()
            comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
            orphantext = "---\norphan: true\n---\n"

            if template.startswith(orphantext):
                template = template.replace(orphantext, orphantext + comment)
            else:
                template = comment + template

            # Render template with rich-generated table
            env = jinja2.Environment(
                trim_blocks=True,
                lstrip_blocks=True,
                # NOTE: autoescape is required to prevent XSS attacks
                autoescape=True,
            )
            template = env.from_string(template)

            default_models = []
            if self.available_models_by_provider:
                has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
                for provider_id, model_entries in self.available_models_by_provider.items():
                    for model_entry in model_entries:
                        doc_parts = []
                        if model_entry.aliases:
                            doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
                        if has_multiple_providers:
                            doc_parts.append(f"provider: {provider_id}")

                        default_models.append(
                            DefaultModel(
                                model_id=model_entry.provider_model_id,
                                doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                            )
                        )

            return template.render(
                name=self.name,
                description=self.description,
                providers=self.providers,
                providers_table=providers_table,
                run_config_env_vars=self.run_config_env_vars,
                default_models=default_models,
            )
        return ""

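    # Editorial sketch (not part of the original file): the jinja2 Environment
    # above renders the distribution's README template with autoescaping on:
    #
    #   >>> jinja2.Environment(autoescape=True).from_string("# {{ name }} distribution").render(name="starter")
    #   '# starter distribution'
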
    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
        def enum_representer(dumper, data):
            return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)

        # Register YAML representer for enums
        yaml.add_representer(ModelType, enum_representer)
        yaml.add_representer(DatasetPurpose, enum_representer)
        yaml.add_representer(StorageBackendType, enum_representer)
        yaml.SafeDumper.add_representer(ModelType, enum_representer)
        yaml.SafeDumper.add_representer(DatasetPurpose, enum_representer)
        yaml.SafeDumper.add_representer(StorageBackendType, enum_representer)

        for output_dir in [yaml_output_dir, doc_output_dir]:
            output_dir.mkdir(parents=True, exist_ok=True)

        build_config = self.build_config()
        with open(yaml_output_dir / "build.yaml", "w") as f:
            yaml.safe_dump(
                filter_empty_values(build_config.model_dump(exclude_none=True)),
                f,
                sort_keys=False,
            )

        for yaml_pth, settings in self.run_configs.items():
            run_config = settings.run_config(self.name, self.providers, self.container_image)
            with open(yaml_output_dir / yaml_pth, "w") as f:
                yaml.safe_dump(
                    filter_empty_values(run_config),
                    f,
                    sort_keys=False,
                )

        if self.template_path:
            docs = self.generate_markdown_docs()
            with open(doc_output_dir / f"{self.name}.md", "w") as f:
                f.write(docs if docs.endswith("\n") else docs + "\n")
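
# Editorial sketch (not part of the original file): the enum representers
# registered in save_distribution let yaml.safe_dump emit enum members as their
# string values instead of raising a RepresenterError. Assuming "kv_sqlite" is
# a valid StorageBackendType value (as seen in the starter run config):
#
#   >>> yaml.safe_dump({"type": StorageBackendType("kv_sqlite")})
#   'type: kv_sqlite\n'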