feat: Add open benchmark template codegen (#1579)

## What does this PR do?

As the title says: add codegen for the open-benchmark template.

## Test

Verified that the generated run.yaml is identical before and after this change.
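
For reference, the check can be reproduced roughly like this (the codegen entry point is an assumption based on the repo's template-codegen script, not something shown in this diff):

```bash
# Regenerate all distribution templates, then confirm the checked-in
# open-benchmark run.yaml is unchanged (script path is an assumption).
python llama_stack/scripts/distro_codegen.py
git diff --exit-code llama_stack/templates/open-benchmark/run.yaml
```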

Also made a small improvement to the together template so that a missing
TOGETHER_API_KEY no longer crashes the server, matching the user experience
of the other remote providers.
Botao Chen 2025-03-12 11:12:08 -07:00 committed by GitHub
parent 4eee349acd
commit 0b0be70605
8 changed files with 430 additions and 84 deletions

### distributions/dependencies.json

```diff
@@ -453,6 +453,40 @@
     "transformers",
     "uvicorn"
   ],
+  "open-benchmark": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "fastapi",
+    "fire",
+    "httpx",
+    "litellm",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "sqlite-vec",
+    "together",
+    "tqdm",
+    "transformers",
+    "uvicorn"
+  ],
   "remote-vllm": [
     "aiosqlite",
     "autoevals",
```

### llama_stack/providers/remote/inference/together/config.py

```diff
@@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.together.xyz/v1",
-            "api_key": "${env.TOGETHER_API_KEY}",
+            "api_key": "${env.TOGETHER_API_KEY:}",
         }
```
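
The trailing colon gives the substitution an empty-string default, so an unset TOGETHER_API_KEY resolves to `""` instead of failing at startup. A minimal sketch of that `${env.VAR:default}` semantics (illustrative only; the actual resolver lives elsewhere in llama_stack and may differ in detail):

```python
import os
import re

# Hypothetical illustration of "${env.VAR}" vs "${env.VAR:default}" handling.
_PATTERN = re.compile(r"\$\{env\.(?P<var>\w+)(?P<colon>:(?P<default>[^}]*))?\}")


def substitute_env(value: str) -> str:
    def _replace(match: re.Match) -> str:
        var, default = match.group("var"), match.group("default")
        if var in os.environ:
            return os.environ[var]
        if match.group("colon") is not None:
            return default or ""  # "${env.X:}" -> "" when X is unset
        raise ValueError(f"Environment variable {var} is required but not set")

    return _PATTERN.sub(_replace, value)


# "${env.TOGETHER_API_KEY}" would raise when the key is unset;
# "${env.TOGETHER_API_KEY:}" quietly resolves to an empty string,
# matching the behavior of the other remote providers.
print(substitute_env("api_key: ${env.TOGETHER_API_KEY:}"))
```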

### llama_stack/templates/open-benchmark/__init__.py

@@ -0,0 +1,7 @@ (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .open_benchmark import get_distribution_template  # noqa: F401
```

### llama_stack/templates/open-benchmark/open_benchmark.py

@ -0,0 +1,293 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Tuple
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
DatasetInput,
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
# in this template, we allow each API key to be optional
providers = [
(
"openai",
[
ProviderModelEntry(
provider_model_id="openai/gpt-4o",
model_type=ModelType.llm,
)
],
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
),
(
"anthropic",
[
ProviderModelEntry(
provider_model_id="anthropic/claude-3-5-sonnet-latest",
model_type=ModelType.llm,
)
],
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
),
(
"gemini",
[
ProviderModelEntry(
provider_model_id="gemini/gemini-1.5-flash",
model_type=ModelType.llm,
)
],
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
),
(
"groq",
[],
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
),
(
"together",
[],
TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
),
]
inference_providers = []
available_models = {}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
available_models[provider_id] = model_entries
return inference_providers, available_models
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": [p.provider_type for p in inference_providers],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "open-benchmark"
vector_io_providers = [
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:}",
user="${env.PGVECTOR_USER:}",
password="${env.PGVECTOR_PASSWORD:}",
),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models) + [
ModelInput(
model_id="meta-llama/Llama-3.3-70B-Instruct",
provider_id="groq",
provider_model_id="groq/llama-3.3-70b-versatile",
model_type=ModelType.llm,
),
ModelInput(
model_id="meta-llama/Llama-3.1-405B-Instruct",
provider_id="together",
provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
model_type=ModelType.llm,
),
]
default_datasets = [
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
metadata={
"path": "llamastack/mmlu_cot",
"name": "all",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
metadata={
"path": "llamastack/gpqa_0shot_cot",
"name": "gpqa_main",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
metadata={
"path": "llamastack/math_500",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
]
default_benchmarks = [
BenchmarkInput(
benchmark_id="meta-reference-simpleqa",
dataset_id="simpleqa",
scoring_functions=["llm-as-judge::405b-simpleqa"],
),
BenchmarkInput(
benchmark_id="meta-reference-mmlu-cot",
dataset_id="mmlu_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-gpqa-cot",
dataset_id="gpqa_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-math-500",
dataset_id="math_500",
scoring_functions=["basic::regex_parser_math_response"],
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running open benchmarks",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers,
"vector_io": vector_io_providers,
},
default_models=default_models,
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_datasets=default_datasets,
default_benchmarks=default_benchmarks,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"TOGETHER_API_KEY": (
"",
"Together API Key",
),
"OPENAI_API_KEY": (
"",
"OpenAI API Key",
),
"GEMINI_API_KEY": (
"",
"Gemini API Key",
),
"ANTHROPIC_API_KEY": (
"",
"Anthropic API Key",
),
"GROQ_API_KEY": (
"",
"Groq API Key",
),
},
)
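
As a quick sanity check of the template wiring, something like the following can enumerate the seeded benchmarks (a sketch using only names from this diff; the dotted import path is an assumption, since the template directory name contains a hyphen and may need the repo's codegen loader instead):

```python
# Quick smoke test of the new template (import path is an assumption).
from llama_stack.templates.open_benchmark import get_distribution_template

template = get_distribution_template()
settings = template.run_configs["run.yaml"]
for benchmark in settings.default_benchmarks:
    # Each BenchmarkInput pairs a dataset with its scoring functions.
    print(f"{benchmark.benchmark_id} -> dataset {benchmark.dataset_id}")
```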

### llama_stack/templates/open-benchmark/run.yaml

```diff
@@ -38,7 +38,7 @@ providers:
   - provider_id: sqlite-vec
     provider_type: inline::sqlite-vec
     config:
-      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
   - provider_id: ${env.ENABLE_CHROMADB+chromadb}
     provider_type: remote::chromadb
     config:
@@ -62,14 +62,14 @@ providers:
       persistence_store:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
   telemetry:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
       service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
+      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -114,18 +114,13 @@ providers:
     config: {}
 metadata_store:
   type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
 models:
 - metadata: {}
   model_id: openai/gpt-4o
   provider_id: openai
   provider_model_id: openai/gpt-4o
   model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-405B-Instruct
-  provider_id: together
-  provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
-  model_type: llm
 - metadata: {}
   model_id: anthropic/claude-3-5-sonnet-latest
   provider_id: anthropic
@@ -141,84 +136,95 @@ models:
   provider_id: groq
   provider_model_id: groq/llama-3.3-70b-versatile
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct
+  provider_id: together
+  provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
+  model_type: llm
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
 vector_dbs: []
 datasets:
-- dataset_id: simpleqa
-  provider_id: huggingface
-  url:
-    uri: https://huggingface.co/datasets/llamastack/simpleqa
-  metadata:
-    path: llamastack/simpleqa
-    name:
-    split: train
-  dataset_schema:
+- dataset_schema:
     input_query:
       type: string
     expected_answer:
       type: string
     chat_completion_input:
       type: string
-- dataset_id: mmlu_cot
+  url:
+    uri: https://huggingface.co/datasets/llamastack/simpleqa
+  metadata:
+    path: llamastack/simpleqa
+    split: train
+  dataset_id: simpleqa
   provider_id: huggingface
+- dataset_schema:
+    input_query:
+      type: string
+    expected_answer:
+      type: string
+    chat_completion_input:
+      type: string
   url:
     uri: https://huggingface.co/datasets/llamastack/mmlu_cot
   metadata:
     path: llamastack/mmlu_cot
     name: all
     split: test
-  dataset_schema:
+  dataset_id: mmlu_cot
+  provider_id: huggingface
+- dataset_schema:
     input_query:
       type: string
     expected_answer:
       type: string
     chat_completion_input:
       type: string
-- dataset_id: gpqa_cot
-  provider_id: huggingface
   url:
     uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
   metadata:
     path: llamastack/gpqa_0shot_cot
     name: gpqa_main
     split: train
-  dataset_schema:
+  dataset_id: gpqa_cot
+  provider_id: huggingface
+- dataset_schema:
     input_query:
       type: string
     expected_answer:
      type: string
     chat_completion_input:
       type: string
-- dataset_id: math_500
-  provider_id: huggingface
   url:
     uri: https://huggingface.co/datasets/llamastack/math_500
   metadata:
     path: llamastack/math_500
-    name:
     split: test
-  dataset_schema:
-    input_query:
-      type: string
-    expected_answer:
-      type: string
-    chat_completion_input:
-      type: string
+  dataset_id: math_500
+  provider_id: huggingface
 scoring_fns: []
 benchmarks:
-- benchmark_id: meta-reference-simpleqa
-  dataset_id: simpleqa
-  scoring_functions: ["llm-as-judge::405b-simpleqa"]
-- benchmark_id: meta-reference-mmlu-cot
-  dataset_id: mmlu_cot
-  scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
-- benchmark_id: meta-reference-gpqa-cot
-  dataset_id: gpqa_cot
-  scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
-- benchmark_id: meta-reference-math-500
-  dataset_id: math_500
-  scoring_functions: ["basic::regex_parser_math_response"]
+- dataset_id: simpleqa
+  scoring_functions:
+  - llm-as-judge::405b-simpleqa
+  metadata: {}
+  benchmark_id: meta-reference-simpleqa
+- dataset_id: mmlu_cot
+  scoring_functions:
+  - basic::regex_parser_multiple_choice_answer
+  metadata: {}
+  benchmark_id: meta-reference-mmlu-cot
+- dataset_id: gpqa_cot
+  scoring_functions:
+  - basic::regex_parser_multiple_choice_answer
+  metadata: {}
+  benchmark_id: meta-reference-gpqa-cot
+- dataset_id: math_500
+  scoring_functions:
+  - basic::regex_parser_math_response
+  metadata: {}
+  benchmark_id: meta-reference-math-500
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
```
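
With the optional-key handling above, the server can be started with any subset of provider keys exported; unset keys now resolve to empty strings instead of aborting startup. A typical launch might look like this (the exact CLI invocation is an assumption, not part of this PR):

```bash
# Hypothetical launch: export only the keys you have; the rest default to "".
export OPENAI_API_KEY=sk-example   # placeholder value
llama stack run open-benchmark --port 5001
```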

### llama_stack/templates/template.py

```diff
@@ -14,7 +14,9 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     Api,
+    BenchmarkInput,
     BuildConfig,
+    DatasetInput,
     DistributionSpec,
     ModelInput,
     Provider,
@@ -56,6 +58,8 @@ class RunConfigSettings(BaseModel):
     default_models: Optional[List[ModelInput]] = None
     default_shields: Optional[List[ShieldInput]] = None
     default_tool_groups: Optional[List[ToolGroupInput]] = None
+    default_datasets: Optional[List[DatasetInput]] = None
+    default_benchmarks: Optional[List[BenchmarkInput]] = None

     def run_config(
         self,
@@ -113,6 +117,8 @@ class RunConfigSettings(BaseModel):
             models=self.default_models or [],
             shields=self.default_shields or [],
             tool_groups=self.default_tool_groups or [],
+            datasets=self.default_datasets or [],
+            benchmarks=self.default_benchmarks or [],
         )
```
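
Templates can now seed datasets and benchmarks declaratively. A condensed sketch of the new fields in use (names taken from this diff, values illustrative, and it assumes the remaining RunConfigSettings fields are optional):

```python
# Illustrative use of the two new RunConfigSettings fields.
from llama_stack.distribution.datatypes import BenchmarkInput, DatasetInput
from llama_stack.templates.template import RunConfigSettings

settings = RunConfigSettings(
    default_datasets=[
        DatasetInput(
            dataset_id="simpleqa",
            provider_id="huggingface",
            url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
            metadata={"path": "llamastack/simpleqa", "split": "train"},
            dataset_schema={
                "input_query": {"type": "string"},
                "expected_answer": {"type": "string"},
                "chat_completion_input": {"type": "string"},
            },
        )
    ],
    default_benchmarks=[
        BenchmarkInput(
            benchmark_id="meta-reference-simpleqa",
            dataset_id="simpleqa",
            scoring_functions=["llm-as-judge::405b-simpleqa"],
        )
    ],
)
```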

### llama_stack/templates/together/run-with-safety.yaml

```diff
@@ -16,7 +16,7 @@ providers:
     provider_type: remote::together
     config:
       url: https://api.together.xyz/v1
-      api_key: ${env.TOGETHER_API_KEY}
+      api_key: ${env.TOGETHER_API_KEY:}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
```

### llama_stack/templates/together/run.yaml

```diff
@@ -16,7 +16,7 @@ providers:
     provider_type: remote::together
     config:
       url: https://api.together.xyz/v1
-      api_key: ${env.TOGETHER_API_KEY}
+      api_key: ${env.TOGETHER_API_KEY:}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
```