Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
# What does this PR do?

- add /eval, /scoring, /datasetio API providers to distribution templates
- regenerate build.yaml / run.yaml files
- fix `template.py` to take in a list of providers instead of only the first one
- override the memory provider to faiss by default for all distros (only one memory provider is needed to start the basic flow; chromadb/pgvector require an additional setup step)

```
python llama_stack/scripts/distro_codegen.py
```

- updated the README to start the UI via conda builds

## Test Plan

```
python llama_stack/scripts/distro_codegen.py
```

- Use the newly generated `run.yaml` to start the server:

```
llama stack run ./llama_stack/templates/together/run.yaml
```

<img width="1191" alt="image" src="https://github.com/user-attachments/assets/62f7d179-0cd0-427c-b6e8-e087d4648f09">

#### Registration

```
❯ llama-stack-client datasets register \
    --dataset-id "mmlu" \
    --provider-id "huggingface" \
    --url "https://huggingface.co/datasets/llamastack/evals" \
    --metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
    --schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'

❯ llama-stack-client datasets list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ metadata                             ┃ type    ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ mmlu       │ huggingface │ {'path': 'llamastack/evals', 'name': │ dataset │
│            │             │ 'evals__mmlu__details', 'split':     │         │
│            │             │ 'train'}                             │         │
└────────────┴─────────────┴──────────────────────────────────────┴─────────┘
```

```
❯ llama-stack-client datasets register \
    --dataset-id "simpleqa" \
    --provider-id "huggingface" \
    --url "https://huggingface.co/datasets/llamastack/evals" \
    --metadata '{"path": "llamastack/evals", "name": "evals__simpleqa", "split": "train"}' \
    --schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'

❯ llama-stack-client datasets list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ metadata                                                ┃ type    ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ mmlu       │ huggingface │ {'path': 'llamastack/evals', 'name':                    │ dataset │
│            │             │ 'evals__mmlu__details', 'split': 'train'}               │         │
│ simpleqa   │ huggingface │ {'path': 'llamastack/evals', 'name': 'evals__simpleqa', │ dataset │
│            │             │ 'split': 'train'}                                       │         │
└────────────┴─────────────┴─────────────────────────────────────────────────────────┴─────────┘
```

```
❯ llama-stack-client eval_tasks register \
    --eval-task-id meta-reference-mmlu \
    --provider-id meta-reference \
    --dataset-id mmlu \
    --scoring-functions basic::regex_parser_multiple_choice_answer

❯ llama-stack-client eval_tasks register \
    --eval-task-id meta-reference-simpleqa \
    --provider-id meta-reference \
    --dataset-id simpleqa \
    --scoring-functions llm-as-judge::405b-simpleqa

❯ llama-stack-client eval_tasks list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ dataset_id ┃ identifier       ┃ metadata ┃ provider_id    ┃ provider_resour… ┃ scoring_functio… ┃ type      ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ mmlu       │ meta-reference-… │ {}       │ meta-reference │ meta-reference-… │ ['basic::regex_… │ eval_task │
│ simpleqa   │ meta-reference-… │ {}       │ meta-reference │ meta-reference-… │ ['llm-as-judge:… │ eval_task │
└────────────┴──────────────────┴──────────┴────────────────┴──────────────────┴──────────────────┴───────────┘
```

#### Test with UI

```
streamlit run app.py
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
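To make the faiss default concrete, here is a minimal usage sketch (not part of the PR itself) of the provider-override path in the `template.py` shown below; the import path and the `inline::faiss` type string are assumptions for illustration:

```python
# Minimal sketch: pin a single faiss memory provider via provider_overrides.
# The module path and "inline::faiss" type string are assumed, not from the PR.
from llama_stack.distribution.datatypes import Provider
from llama_stack.templates.template import RunConfigSettings

settings = RunConfigSettings(
    provider_overrides={
        # APIs without an override fall back to registry-derived defaults
        # inside run_config().
        "memory": [
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",  # assumed provider type string
                config={},
            )
        ]
    },
)
run_config = settings.run_config(
    name="together",
    providers={"memory": ["inline::faiss"]},
)
```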
`template.py` · 166 lines · 5.8 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple

import jinja2
import yaml
from pydantic import BaseModel, Field

from llama_stack.distribution.datatypes import (
    Api,
    BuildConfig,
    DistributionSpec,
    ModelInput,
    Provider,
    ShieldInput,
    StackRunConfig,
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class RunConfigSettings(BaseModel):
    provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
    default_models: Optional[List[ModelInput]] = None
    default_shields: Optional[List[ShieldInput]] = None

    def run_config(
        self,
        name: str,
        providers: Dict[str, List[str]],
        docker_image: Optional[str] = None,
    ) -> StackRunConfig:
        provider_registry = get_provider_registry()

        provider_configs = {}
        for api_str, provider_types in providers.items():
            # An explicit override for this API wins over registry-derived providers
            if api_providers := self.provider_overrides.get(api_str):
                provider_configs[api_str] = api_providers
                continue

            provider_configs[api_str] = []
            for provider_type in provider_types:
                provider_id = provider_type.split("::")[-1]

                api = Api(api_str)
                if provider_type not in provider_registry[api]:
                    raise ValueError(
                        f"Unknown provider type: {provider_type} for API: {api_str}"
                    )

                config_class = provider_registry[api][provider_type].config_class
                assert (
                    config_class is not None
                ), f"No config class for provider type: {provider_type} for API: {api_str}"

                config_class = instantiate_class_type(config_class)
                if hasattr(config_class, "sample_run_config"):
                    config = config_class.sample_run_config(
                        __distro_dir__=f"distributions/{name}"
                    )
                else:
                    config = {}

                provider_configs[api_str].append(
                    Provider(
                        provider_id=provider_id,
                        provider_type=provider_type,
                        config=config,
                    )
                )

        # Get unique set of APIs from providers
        apis = list(sorted(providers.keys()))

        return StackRunConfig(
            image_name=name,
            docker_image=docker_image,
            conda_env=name,
            apis=apis,
            providers=provider_configs,
            metadata_store=SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=f"distributions/{name}",
                db_name="registry.db",
            ),
            models=self.default_models or [],
            shields=self.default_shields or [],
        )
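
# Usage sketch (illustrative, not part of the original file): without an
# override, run_config() derives one Provider per registered provider type;
# an entry in provider_overrides replaces the derived list for that API.
#
#   settings = RunConfigSettings()
#   config = settings.run_config(
#       name="together",
#       providers={"inference": ["remote::together"]},
#   )
#   # config.providers["inference"][0].provider_id == "together"
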
class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.
    """

    name: str
    description: str
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]

    providers: Dict[str, List[str]]
    run_configs: Dict[str, RunConfigSettings]
    template_path: Optional[Path] = None

    # Optional configuration
    run_config_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
    docker_image: Optional[str] = None

    default_models: Optional[List[ModelInput]] = None

    def build_config(self) -> BuildConfig:
        return BuildConfig(
            name=self.name,
            distribution_spec=DistributionSpec(
                description=self.description,
                docker_image=self.docker_image,
                providers=self.providers,
            ),
            image_type="conda",  # default to conda, can be overridden
        )
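
    # For orientation (illustrative): after model_dump(), the emitted
    # build.yaml has roughly this shape:
    #
    #   name: together
    #   distribution_spec:
    #     description: ...
    #     docker_image: null
    #     providers:
    #       inference:
    #       - remote::together
    #   image_type: conda
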
    def generate_markdown_docs(self) -> str:
        providers_table = "| API | Provider(s) |\n"
        providers_table += "|-----|-------------|\n"

        for api, providers in sorted(self.providers.items()):
            providers_str = ", ".join(f"`{p}`" for p in providers)
            providers_table += f"| {api} | {providers_str} |\n"

        template = self.template_path.read_text()
        # Render template with rich-generated table
        env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(template)
        return template.render(
            name=self.name,
            description=self.description,
            providers=self.providers,
            providers_table=providers_table,
            run_config_env_vars=self.run_config_env_vars,
            default_models=self.default_models,
        )

    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
        for output_dir in [yaml_output_dir, doc_output_dir]:
            output_dir.mkdir(parents=True, exist_ok=True)

        build_config = self.build_config()
        with open(yaml_output_dir / "build.yaml", "w") as f:
            yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)

        for yaml_pth, settings in self.run_configs.items():
            run_config = settings.run_config(
                self.name, self.providers, self.docker_image
            )
            with open(yaml_output_dir / yaml_pth, "w") as f:
                yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)

        if self.template_path:
            docs = self.generate_markdown_docs()
            with open(doc_output_dir / f"{self.name}.md", "w") as f:
                f.write(docs if docs.endswith("\n") else docs + "\n")
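
# End-to-end sketch (illustrative): a codegen driver such as
# llama_stack/scripts/distro_codegen.py builds one DistributionTemplate per
# distro and calls save_distribution(); names and paths below are assumed.
#
#   template = DistributionTemplate(
#       name="together",
#       description="Use Together.AI for running LLM inference",
#       distro_type="self_hosted",
#       providers={"inference": ["remote::together"]},
#       run_configs={"run.yaml": RunConfigSettings()},
#   )
#   template.save_distribution(
#       yaml_output_dir=Path("llama_stack/templates/together"),
#       doc_output_dir=Path("docs/distributions"),  # assumed path
#   )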