forked from phoenix-oss/llama-stack-mirror
feat: open benchmark template and doc (#1465)
## What does this PR do?
- Provide a distro template so developers can easily run the open benchmarks Llama Stack supports on Llama and non-Llama models.
- Provide docs on how to run open-benchmark evals via the CLI, plus an open-benchmark contributing guide.

Closes #1375

## Test Plan
Open benchmark eval results on Llama, GPT, Gemini and Claude:
<img width="771" alt="Screenshot 2025-03-06 at 7 33 05 PM" src="https://github.com/user-attachments/assets/1bd85456-b9b9-4b37-af76-4ce1d2bac00e" />

Doc preview:
<img width="944" alt="Screenshot 2025-03-06 at 7 33 58 PM" src="https://github.com/user-attachments/assets/f4e5866d-b395-4c40-aa8b-080edeb5cdb6" />
<img width="955" alt="Screenshot 2025-03-06 at 7 34 04 PM" src="https://github.com/user-attachments/assets/629defb6-d5e4-473c-aa03-308bce386fb4" />
<img width="965" alt="Screenshot 2025-03-06 at 7 35 29 PM" src="https://github.com/user-attachments/assets/c21ff96c-9e8c-4c54-b6b8-25883125f4cf" />
<img width="957" alt="Screenshot 2025-03-06 at 7 35 37 PM" src="https://github.com/user-attachments/assets/47571c90-1381-4e2c-bbed-c4f3a60578d0" />
parent 290cc843fc · commit 4dccf916d1
7 changed files with 585 additions and 10 deletions
```diff
@@ -453,6 +453,42 @@
     "transformers",
     "uvicorn"
   ],
+  "open_benchmark": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "fastapi",
+    "fire",
+    "httpx",
+    "litellm",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "sqlite-vec",
+    "together",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+  ],
   "remote-vllm": [
     "aiosqlite",
     "autoevals",
```
@@ -24,6 +24,56 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo

- Associated with `Benchmark` resource.

## Open-benchmark Eval

### List of open-benchmarks Llama Stack supports

Llama Stack pre-registers several popular open-benchmarks so you can easily evaluate model performance via the CLI.

The list of open-benchmarks we currently support:
- [MMLU-COT](https://arxiv.org/abs/2009.03300) (Measuring Massive Multitask Language Understanding): Benchmark designed to comprehensively evaluate the breadth and depth of a model's academic and professional understanding.
- [GPQA-COT](https://arxiv.org/abs/2311.12022) (A Graduate-Level Google-Proof Q&A Benchmark): A challenging benchmark of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess a model's ability to answer short, fact-seeking questions.
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.

You can follow the contributing guide to add more open-benchmarks to Llama Stack.

### Run evaluation on open-benchmarks via CLI

We have built-in functionality to run the supported open-benchmarks using the llama-stack-client CLI.

#### Spin up Llama Stack server

Spin up a Llama Stack server with the 'open-benchmark' template
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```

#### Run eval CLI
There are 3 necessary inputs to run a benchmark eval:
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `output_dir`: Path to store the evaluation results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluation results> \
```

You can run
```
llama-stack-client eval run-benchmark help
```
to see the description of all the flags that eval run-benchmark supports.

In the output log, you can find the path to the file that contains your evaluation results. Open that file to see your aggregate evaluation results.

## What's Next?

- Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).
````diff
@@ -275,18 +275,25 @@ response = client.scoring.score(
 The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.
 
 #### Benchmark Evaluation CLI
-Usage: There are 2 inputs necessary for running a benchmark eval
-- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
-- `dataset_id`: the identifier associated with the dataset.
-- `List[scoring_function_id]`: list of scoring function identifiers.
-- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
-
-```
-llama-stack-client eval run_benchmark <eval-task-id> \
---eval-task-config ~/benchmark_config.json \
---visualize
-```
+There are 3 necessary inputs to run a benchmark eval:
+- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
+- `model-id`: The model id to evaluate on
+- `output_dir`: Path to store the evaluation results
+```
+llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
+--model_id <model id to evaluate on> \
+--output_dir <directory to store the evaluation results> \
+```
+
+You can run
+```
+llama-stack-client eval run-benchmark help
+```
+to see the description of all the flags for running a benchmark eval.
+
+In the output log, you can find the path to the file that contains your evaluation results. Open that file to see your aggregate
+evaluation results.
 
 #### Application Evaluation CLI
````
@@ -338,3 +345,52 @@ The `BenchmarkConfig` are user specified config to define:

## Open-benchmark Contributing Guide

### Create the new dataset for your new benchmark

An eval open-benchmark essentially contains 2 parts:
- `raw data`: The raw dataset associated with the benchmark. You typically need to search the original paper that introduces the benchmark and find the canonical dataset (usually hosted on huggingface).
- `prompt template`: How to ask the candidate model to generate the answer (the prompt template plays a critical role in the evaluation results). Typically, you can find the reference prompt template associated with the benchmark in the benchmark author's repo ([example](https://github.com/idavidrein/gpqa/blob/main/prompts/chain_of_thought.txt)) or some other popular open source repos ([example](https://github.com/openai/simple-evals/blob/0a6e8f62e52bc5ae915f752466be3af596caf392/common.py#L14)).

To create a new open-benchmark in Llama Stack, you need to combine the prompt template and the raw data into the `chat_completion_input` column in the evaluation dataset.

Llama Stack enforces that the evaluation dataset schema contains at least 3 columns:
- `chat_completion_input`: The actual input to the model to run the generation for eval
- `input_query`: The raw input from the raw dataset without the prompt template
- `expected_answer`: The ground truth for scoring functions to calculate the score from

You need to write a script ([example convert script](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840)) to convert the benchmark's raw dataset into the Llama Stack eval dataset format and upload the dataset to huggingface ([example benchmark dataset](https://huggingface.co/datasets/llamastack/mmmu)). A minimal sketch of such a conversion is shown below.
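To make the conversion step concrete, here is a minimal sketch, assuming a hypothetical multiple-choice source dataset with `question`, `choices`, and `answer` columns and an illustrative prompt template; it is not the referenced gist, so adapt the column names, template, and repository ids to your benchmark. It serializes `chat_completion_input` as a JSON string, matching the `type: string` schema declared in the template's run.yaml.

```python
# Illustrative conversion sketch (not the referenced gist). The source dataset,
# its column names, and the prompt template are hypothetical placeholders.
import json

from datasets import load_dataset

PROMPT_TEMPLATE = """Answer the following multiple choice question. The last line of your
response should be of the format 'Answer: $LETTER' (without quotes).

{question}

A) {a}
B) {b}
C) {c}
D) {d}
"""


def to_llama_stack_row(row):
    prompt = PROMPT_TEMPLATE.format(
        question=row["question"],
        a=row["choices"][0],
        b=row["choices"][1],
        c=row["choices"][2],
        d=row["choices"][3],
    )
    return {
        # what the candidate model actually sees during generation
        "chat_completion_input": json.dumps([{"role": "user", "content": prompt}]),
        # the raw question, without the prompt template
        "input_query": row["question"],
        # ground truth that the scoring function compares against
        "expected_answer": row["answer"],
    }


raw = load_dataset("some-org/raw-benchmark", split="test")  # hypothetical source
converted = raw.map(to_llama_stack_row, remove_columns=raw.column_names)
converted.push_to_hub("your-org/your-benchmark-eval")  # hypothetical target repo
```

Once the converted dataset is on huggingface, it can be referenced from the `datasets` resource in the same way as the `simpleqa`, `mmlu_cot`, and `gpqa_cot` entries in the new run.yaml later in this diff.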
### Find scoring function for your new benchmark

The purpose of a scoring function is to calculate the score for each example based on the candidate model's generation result and the `expected_answer`. It also aggregates the scores across all the examples to produce the final evaluation results.

First, check whether the existing [llama stack scoring functions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/scoring) can fulfill your need. If not, you need to write a new scoring function based on what the benchmark author or other open source repos describe. Conceptually, a scoring function looks like the sketch below.
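For intuition only, the following self-contained sketch (not the Llama Stack scoring provider API) shows what such a function does for a multiple-choice benchmark: parse the model's generation, compare it against `expected_answer`, and aggregate per-example scores into a final accuracy, similar in spirit to `basic::regex_parser_multiple_choice_answer`.

```python
# Conceptual sketch of a scoring function: per-example scoring + aggregation.
# This is not the Llama Stack provider interface, just the underlying idea.
import re
from typing import List


def score_row(generation: str, expected_answer: str) -> float:
    """Score one example: extract 'Answer: X' from the generation and compare."""
    match = re.search(r"Answer\s*:\s*([A-D])", generation)
    predicted = match.group(1) if match else None
    return 1.0 if predicted == expected_answer else 0.0


def aggregate(scores: List[float]) -> dict:
    """Aggregate per-example scores into the final evaluation result."""
    return {
        "accuracy": sum(scores) / len(scores) if scores else 0.0,
        "num_examples": len(scores),
    }


# Example usage with toy data
rows = [
    {"generation": "Reasoning...\nAnswer: B", "expected_answer": "B"},
    {"generation": "I think the answer is C", "expected_answer": "A"},
]
scores = [score_row(r["generation"], r["expected_answer"]) for r in rows]
print(aggregate(scores))  # {'accuracy': 0.5, 'num_examples': 2}
```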
### Add new benchmark into template

First, add the evaluation dataset associated with your benchmark under the `datasets` resource in templates/open-benchmark/run.yaml.

Second, add the new benchmark you just created under the `benchmarks` resource in the same template (a client-side registration sketch follows this list). To add the new benchmark, you need to provide:
- `benchmark_id`: identifier of the benchmark
- `dataset_id`: identifier of the dataset associated with your benchmark
- `scoring_functions`: scoring function to calculate the score based on generation results and expected_answer
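If you want to try a new benchmark without editing the template, the llama-stack-client Python SDK can register the same resources at runtime. This is a hedged sketch: it assumes your client version exposes `datasets.register` and `benchmarks.register` with fields mirroring the run.yaml entries, and every id and URL is a placeholder; check your client's reference docs if the parameter names differ.

```python
# Hedged sketch: assumes the installed llama-stack-client exposes register()
# methods mirroring the `datasets` and `benchmarks` entries in run.yaml.
# Every id and URL below is a placeholder.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.datasets.register(
    dataset_id="my_new_benchmark_dataset",
    provider_id="huggingface",
    url={"uri": "https://huggingface.co/datasets/your-org/your-benchmark-eval"},
    metadata={"path": "your-org/your-benchmark-eval", "split": "test"},
    dataset_schema={
        "input_query": {"type": "string"},
        "expected_answer": {"type": "string"},
        "chat_completion_input": {"type": "string"},
    },
)

client.benchmarks.register(
    benchmark_id="meta-reference-my-new-benchmark",
    dataset_id="my_new_benchmark_dataset",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
```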
### Test the new benchmark

Spin up a Llama Stack server with the 'open-benchmark' template
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```

Run the eval benchmark CLI with your new benchmark id
```
llama-stack-client eval run-benchmark <new_benchmark_id> \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluation results> \
```
llama_stack/templates/open-benchmark/__init__.py (new file, 7 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .open_benchmark import get_distribution_template  # noqa: F401
```
llama_stack/templates/open-benchmark/build.yaml (new file, 36 lines)

```yaml
version: '2'
distribution_spec:
  description: Distribution for running open benchmarks
  providers:
    inference:
    - remote::openai
    - remote::anthropic
    - remote::gemini
    - remote::groq
    - remote::together
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
```
llama_stack/templates/open-benchmark/open_benchmark.py (new file, 178 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Tuple

from llama_stack.distribution.datatypes import (
    ModelInput,
    Provider,
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.anthropic.models import MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry


def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
    # in this template, we allow each API key to be optional
    providers = [
        (
            "openai",
            OPENAI_MODEL_ENTRIES,
            OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
        ),
        (
            "anthropic",
            ANTHROPIC_MODEL_ENTRIES,
            AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
        ),
        (
            "gemini",
            GEMINI_MODEL_ENTRIES,
            GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
        ),
        (
            "groq",
            GROQ_MODEL_ENTRIES,
            GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
        ),
        (
            "together",
            TOGETHER_MODEL_ENTRIES,
            TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
        ),
    ]
    inference_providers = []
    available_models = {}
    for provider_id, model_entries, config in providers:
        inference_providers.append(
            Provider(
                provider_id=provider_id,
                provider_type=f"remote::{provider_id}",
                config=config,
            )
        )
        available_models[provider_id] = model_entries
    return inference_providers, available_models


def get_distribution_template() -> DistributionTemplate:
    inference_providers, available_models = get_inference_providers()
    providers = {
        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }
    name = "open_benchmark"

    vector_io_providers = [
        Provider(
            provider_id="sqlite-vec",
            provider_type="inline::sqlite-vec",
            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
        ),
        Provider(
            provider_id="${env.ENABLE_CHROMADB+chromadb}",
            provider_type="remote::chromadb",
            config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
        ),
        Provider(
            provider_id="${env.ENABLE_PGVECTOR+pgvector}",
            provider_type="remote::pgvector",
            config=PGVectorVectorIOConfig.sample_run_config(
                db="${env.PGVECTOR_DB:}",
                user="${env.PGVECTOR_USER:}",
                password="${env.PGVECTOR_PASSWORD:}",
            ),
        ),
    ]

    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]

    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running open benchmarks",
        container_image=None,
        template_path=None,
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": inference_providers,
                    "vector_io": vector_io_providers,
                },
                default_models=default_models,
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "OPENAI_API_KEY": (
                "",
                "OpenAI API Key",
            ),
            "GEMINI_API_KEY": (
                "",
                "Gemini API Key",
            ),
            "GROQ_API_KEY": (
                "",
                "Groq API Key",
            ),
            "ANTHROPIC_API_KEY": (
                "",
                "Anthropic API Key",
            ),
            "TOGETHER_API_KEY": (
                "",
                "Together API Key",
            ),
        },
    )
```
llama_stack/templates/open-benchmark/run.yaml (new file, 212 lines)

```yaml
version: '2'
image_name: dev
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: openai
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
      api_key: ${env.ANTHROPIC_API_KEY:}
  - provider_id: gemini
    provider_type: remote::gemini
    config:
      api_key: ${env.GEMINI_API_KEY:}
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY:}
  - provider_id: together
    provider_type: remote::together
    config:
      url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY}
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:}
  - provider_id: ${env.ENABLE_PGVECTOR+pgvector}
    provider_type: remote::pgvector
    config:
      host: ${env.PGVECTOR_HOST:localhost}
      port: ${env.PGVECTOR_PORT:5432}
      db: ${env.PGVECTOR_DB:}
      user: ${env.PGVECTOR_USER:}
      password: ${env.PGVECTOR_PASSWORD:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
models:
- metadata: {}
  model_id: openai/gpt-4o
  provider_id: openai
  provider_model_id: openai/gpt-4o
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct
  provider_id: together
  provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
  model_type: llm
- metadata: {}
  model_id: anthropic/claude-3-5-sonnet-latest
  provider_id: anthropic
  provider_model_id: anthropic/claude-3-5-sonnet-latest
  model_type: llm
- metadata: {}
  model_id: gemini/gemini-1.5-flash
  provider_id: gemini
  provider_model_id: gemini/gemini-1.5-flash
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: groq
  provider_model_id: groq/llama-3.3-70b-versatile
  model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- dataset_id: simpleqa
  provider_id: huggingface
  url:
    uri: https://huggingface.co/datasets/llamastack/simpleqa
  metadata:
    path: llamastack/simpleqa
    name:
    split: train
  dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
- dataset_id: mmlu_cot
  provider_id: huggingface
  url:
    uri: https://huggingface.co/datasets/llamastack/mmlu_cot
  metadata:
    path: llamastack/mmlu_cot
    name: all
    split: test
  dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
- dataset_id: gpqa_cot
  provider_id: huggingface
  url:
    uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
  metadata:
    path: llamastack/gpqa_0shot_cot
    name: gpqa_main
    split: train
  dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
scoring_fns: []
benchmarks:
- benchmark_id: meta-reference-simpleqa
  dataset_id: simpleqa
  scoring_functions: ["llm-as-judge::405b-simpleqa"]
- benchmark_id: meta-reference-mmlu-cot
  dataset_id: mmlu_cot
  scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
- benchmark_id: meta-reference-gpqa-cot
  dataset_id: gpqa_cot
  scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
```