Compare commits

...
Sign in to create a new pull request.

85 commits

Author SHA1 Message Date
Xi Yan
7854885e5a openapi 2025-03-26 12:32:24 -07:00
Xi Yan
cbb53af701 distro codegen 2025-03-26 12:31:08 -07:00
Xi Yan
bc0cd07008 Merge branch 'main' into eval_api_final 2025-03-26 12:29:45 -07:00
Xi Yan
7f12ea290f
feat(eval api): (2.3/n) remove scoring / eval impls + benchmarks (#1766)
# What does this PR do?
- Remove `/eval` and `/scoring` impls
- Clean up benchmarks. The benchmarks exists in the `llama-stack-evals`
repo.
- Rest of grading functions will be added in follow up PR. 

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
- CI

[//]: # (## Documentation)
2025-03-23 16:51:17 -07:00
Xi Yan
97e7717c9b fix precommit 2025-03-23 16:42:50 -07:00
Xi Yan
81bc051411 fix precommit 2025-03-23 16:32:06 -07:00
Xi Yan
5038f0e376 precommit 2025-03-23 16:27:56 -07:00
Xi Yan
c2eb47d7e6 precommit 2025-03-23 16:22:32 -07:00
Xi Yan
2723b05164 remove aggregation functions 2025-03-23 16:17:09 -07:00
Xi Yan
64388de068 precommit 2025-03-23 16:15:08 -07:00
Xi Yan
3f8c7a584a precommit 2025-03-23 16:00:48 -07:00
Xi Yan
45f6d5cd08 remove ScoringFn 2025-03-23 15:57:48 -07:00
Xi Yan
a54d757ade merge 2025-03-23 15:48:14 -07:00
Xi Yan
c1d18283d2
feat(eval api): (2.2/n) delete eval / scoring / scoring_fn apis (#1700)
# What does this PR do?
- To make it easier, delete existing `eval/scoring/scoring_function`
apis. There will be a bunch of broken impls here. The sequence is:
1. migrate benchmark graders
2. clean up existing scoring functions

- Add a skeleton evaluation impl to make tests pass. 

## Test Plan
tested in following PRs

[//]: # (## Documentation)
2025-03-19 11:04:23 -07:00
Xi Yan
0048274ec0 update 2025-03-19 09:49:53 -07:00
Xi Yan
42447729e4 update 2025-03-19 09:48:30 -07:00
Xi Yan
a92756a4b7 result_data in evaluation response 2025-03-18 22:09:35 -07:00
Xi Yan
08c0c5505e
feat(eval api): (2.1/n) fix resolver for benchmark routing table + fix precommit (#1691)
# What does this PR do?
- fixes routing table so that `llama stack run` works
- fixes pre-commit
- one of many fixes to address implementation fix

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
llama stack run
```

[//]: # (## Documentation)
2025-03-18 21:09:49 -07:00
Xi Yan
bf135f38b1 precommit 2025-03-18 20:48:03 -07:00
Xi Yan
205a50f10b openapi gen 2025-03-18 20:38:05 -07:00
Xi Yan
24d48b3692 Merge branch 'main' into eval_api_final 2025-03-18 20:17:24 -07:00
Xi Yan
913e6eb50f
Update llama_stack/apis/evaluation/evaluation.py
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-03-18 20:16:24 -07:00
Xi Yan
820b9a00c7
Update llama_stack/apis/evaluation/evaluation.py
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-03-18 20:16:15 -07:00
Xi Yan
85cad639ca
Update llama_stack/apis/evaluation/evaluation.py
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-03-18 20:16:08 -07:00
Xi Yan
d994499f09 update EvaluationTask 2025-03-18 19:30:01 -07:00
Xi Yan
f107e3229b update EvaluationTask 2025-03-18 19:28:34 -07:00
Xi Yan
5e817cd56a update 2025-03-18 18:16:00 -07:00
Xi Yan
398319fe7a agent config 2025-03-18 18:14:55 -07:00
Xi Yan
238cdc4e69 grading 2025-03-18 18:12:06 -07:00
Xi Yan
b98497ee56 docs 2025-03-18 18:10:45 -07:00
Xi Yan
e860c536da pre 2025-03-18 18:01:40 -07:00
Xi Yan
a69759613a comments 2025-03-18 15:01:41 -07:00
Xi Yan
a8b0467ec3 address 2025-03-18 14:51:52 -07:00
Xi Yan
5c0888c29a comments 2025-03-18 14:50:19 -07:00
Xi Yan
46f2ba5910 Merge branch 'main' into eval_api_final 2025-03-18 14:49:57 -07:00
Xi Yan
ade3391170 revert job related 2025-03-17 17:12:28 -07:00
Xi Yan
452b2b1284 precommit 2025-03-17 17:08:21 -07:00
Xi Yan
66cd83fb58 Merge branch 'main' into eval_api_final 2025-03-17 17:00:30 -07:00
Xi Yan
62abe2899a pre 2025-03-16 20:49:09 -07:00
Xi Yan
cb492eba37 inline -> sync 2025-03-16 20:46:07 -07:00
Xi Yan
1860751655 benchmarks 2025-03-16 19:41:40 -07:00
Xi Yan
c80d1f906b precommit 2025-03-16 19:40:18 -07:00
Xi Yan
035b2dcb60 new apis 2025-03-16 19:33:57 -07:00
Xi Yan
d34b70e3ab grader 2025-03-16 18:30:06 -07:00
Xi Yan
d9264a0925 dataaset 2025-03-16 16:34:56 -07:00
Xi Yan
63f1525165 precommit 2025-03-16 16:27:29 -07:00
Xi Yan
5cf7779b8f fix integeration 2025-03-15 17:36:39 -07:00
Xi Yan
a6fa3aa5a2
feat(dataset api): (1.6/n) fix all iterrows callsites (#1660)
# What does this PR do?
- as title

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
CI

```
pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
<img width="587" alt="image"
src="https://github.com/user-attachments/assets/4a25f493-501e-43f4-9836-d9802223a93a"
/>


[//]: # (## Documentation)
2025-03-15 17:24:16 -07:00
Xi Yan
f2d93324e9 pre 2025-03-15 17:08:32 -07:00
Xi Yan
28b8c1c815 scoring fix 2025-03-15 17:06:53 -07:00
Xi Yan
6f5df08ebf fix hf url endpoint 2025-03-15 16:50:43 -07:00
Xi Yan
a568bf3f9d
feat(dataset api): (1.5/n) fix dataset registeration (#1659)
# What does this PR do?

- fix dataset registeration & iterrows
> NOTE: the URL endpoint is changed to datasetio due to flaky path
routing

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/0168b352-1c5a-48d1-8e9a-93141d418e54"
/>


[//]: # (## Documentation)
2025-03-15 16:48:09 -07:00
Xi Yan
2c9d624910
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
# What does this PR do?
- fix datasets api signature mis-match so that llama stack run can start

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
llama stack run
```
<img width="626" alt="image"
src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41"
/>


[//]: # (## Documentation)
2025-03-15 14:56:11 -07:00
Xi Yan
72ccdc19a8
feat(datasets api): (1.3/n) patch OpenAPI gen for datasetio->datasets (#1657)
# What does this PR do?
- We need to tag DatasetIO class correctly with Datasets with the
endpoint change

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
**Before**
<img width="1474" alt="image"
src="https://github.com/user-attachments/assets/48737317-28a3-4aa6-a1b5-e1ea680cef84"
/>


**After**
<img width="1508" alt="image"
src="https://github.com/user-attachments/assets/123322f0-a52f-47ee-99a7-ecc66c1b09ec"
/>

[//]: # (## Documentation)
2025-03-15 14:12:45 -07:00
Xi Yan
5cb0ad7d7f openapi gen + precommit fix 2025-03-15 14:08:01 -07:00
Xi Yan
39f4dfbf50
feat(api): (1.2/n) datasets.iterrorws pagination api updates (#1656)
# What does this PR do?
- as title
- uses "cursor" pagination scheme for iterrows

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
<img width="1226" alt="image"
src="https://github.com/user-attachments/assets/3220eaac-7117-4d0a-b344-2bbb77a22065"
/>


[//]: # (## Documentation)
2025-03-15 13:58:47 -07:00
Xi Yan
c7d741d89e Merge branch 'main' into pr1573 2025-03-15 13:40:18 -07:00
Xi Yan
cba4842a87 Merge branch 'main' into pr1573 2025-03-14 15:58:27 -07:00
Xi Yan
0e2a13da9c Merge branch 'main' into pr1573 2025-03-14 12:18:00 -07:00
Xi Yan
7606e49dbc
feat(dataset api): (1.1/n) dataset api implementation fix pre-commit (#1625)
# What does this PR do?
- fix pre-commit with api updates

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
pre-commit
```

[//]: # (## Documentation)
2025-03-13 16:41:03 -07:00
Xi Yan
a6095820af docs 2025-03-13 14:48:11 -07:00
Xi Yan
89885fd2fa datasetio->datasets 2025-03-13 14:47:06 -07:00
Xi Yan
78ec3d98f6 Merge branch 'main' into pr1573 2025-03-13 11:05:04 -07:00
Xi Yan
8b80a77fae docs 2025-03-12 23:50:52 -07:00
Xi Yan
8a6fa41a93 more purposes 2025-03-12 23:44:18 -07:00
Xi Yan
0df33049e3 update doc 2025-03-12 23:32:54 -07:00
Xi Yan
b4d118fc5c update doc 2025-03-12 23:30:47 -07:00
Xi Yan
772339bebf update doc 2025-03-12 23:27:45 -07:00
Xi Yan
4f6f0f6a91 update doc 2025-03-12 23:27:01 -07:00
Xi Yan
4cc1958af9 huggingface obey consistency 2025-03-12 21:37:13 -07:00
Xi Yan
09039eca57 source 2025-03-12 18:52:05 -07:00
Xi Yan
790b2d5cc0 source 2025-03-12 18:51:46 -07:00
Xi Yan
a3173e8284 update 2025-03-12 18:46:40 -07:00
Xi Yan
18de4cd08a comments 2025-03-12 18:38:07 -07:00
Xi Yan
8942071b3b Merge branch 'main' into pr1573 2025-03-12 18:23:39 -07:00
Xi Yan
f840018088 Merge branch 'main' into pr1573 2025-03-12 12:31:49 -07:00
Xi Yan
31e3409909 Merge branch 'main' into pr1573 2025-03-12 11:38:02 -07:00
Xi Yan
1d80ec7f81 upgrade doc 2025-03-12 00:17:58 -07:00
Xi Yan
0abedd070c comment 2025-03-12 00:13:27 -07:00
Xi Yan
817331e76e precommit 2025-03-11 18:34:38 -07:00
Xi Yan
0e47c65051 update 2025-03-11 18:29:55 -07:00
Xi Yan
02aa9a1e85 remove json_schema_type decorator 2025-03-11 16:08:06 -07:00
Xi Yan
0e8a53ab69 openapi 2025-03-11 15:03:48 -07:00
Xi Yan
8592c2b48a precommit 2025-03-11 14:56:12 -07:00
Xi Yan
bc551e6459 datasets api
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
2025-03-11 14:44:49 -07:00
166 changed files with 2513 additions and 11289 deletions

View file

@ -1,23 +1,19 @@
{
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -25,7 +21,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -33,27 +28,22 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"cerebras": [
"aiosqlite",
"autoevals",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -61,7 +51,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -69,29 +58,24 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"ci-tests": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -99,7 +83,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -108,7 +91,6 @@
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
@ -116,22 +98,18 @@
"dell": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -139,7 +117,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -147,30 +124,25 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -178,7 +150,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -187,30 +158,25 @@
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -218,7 +184,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -226,28 +191,23 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"groq": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -255,7 +215,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -263,29 +222,24 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -293,7 +247,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -301,29 +254,24 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -331,7 +279,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -339,7 +286,6 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
@ -347,24 +293,20 @@
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -372,7 +314,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -383,32 +324,27 @@
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -416,7 +352,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -428,7 +363,6 @@
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
@ -437,12 +371,10 @@
"aiosqlite",
"blobfile",
"chardet",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -454,7 +386,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -462,29 +393,24 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"ollama": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"ollama",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -492,7 +418,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -500,65 +425,22 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"open-benchmark": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"passthrough": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -566,7 +448,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -574,24 +455,20 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -604,7 +481,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -612,7 +488,6 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
@ -649,23 +524,19 @@
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -673,7 +544,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -681,29 +551,24 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -711,7 +576,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -720,29 +584,24 @@
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
@ -750,7 +609,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -758,7 +616,6 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -7,11 +7,9 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |

View file

@ -14,10 +14,8 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::bedrock` |
| safety | `remote::bedrock` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -7,10 +7,8 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::cerebras`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::fireworks`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-groq` distribution consists of the following provid
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::groq` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference-quantized` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::ollama` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-passthrough` distribution consists of the following
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::passthrough`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -16,10 +16,8 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::vllm`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -18,10 +18,8 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -17,10 +17,8 @@ The `llamastack/distribution-together` distribution consists of the following pr
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::together`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -12,11 +12,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
"""
:param dataset_id: The ID of the dataset to used to run the benchmark.
:param grader_ids: The grader ids to use for this benchmark.
:param metadata: Metadata for this benchmark for additional descriptions.
"""
dataset_id: str
scoring_functions: List[str]
grader_ids: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
description="Metadata for this benchmark",
)
@ -45,22 +51,46 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
@webmethod(route="/benchmarks", method="POST")
async def register_benchmark(
self,
dataset_id: str,
grader_ids: List[str],
benchmark_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Benchmark:
"""
Register a new benchmark. A benchmark consists of a dataset id and a list of grader ids.
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
:param dataset_id: The ID of the dataset to be used to run the benchmark. ID obtained through `datasets.register()`
:param grader_ids: List of grader ids to use for this benchmark. ID obtained through `graders.register()`
:param benchmark_id: (Optional) The ID of the benchmark to register. If not provided, an ID will be generated.
:param metadata: (Optional) Metadata for this benchmark for additional descriptions.
"""
...
@webmethod(route="/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""
List all benchmarks.
"""
...
@webmethod(route="/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark: ...
) -> Benchmark:
"""
Get a benchmark by ID.
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
:param benchmark_id: The ID of the benchmark to get.
"""
...
@webmethod(route="/benchmarks/{benchmark_id}", method="DELETE")
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""
Unregister a benchmark by ID.
"""
...

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
@ -10,6 +11,18 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
# NOTE: this will be DEPRECATED in favour of CommonJobFields
job_id: str
class JobType(Enum):
batch_inference = "batch_inference"
evaluation = "evaluation"
finetuning = "finetuning"
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
@ -19,6 +32,17 @@ class JobStatus(Enum):
@json_schema_type
class Job(BaseModel):
job_id: str
class CommonJobFields(BaseModel):
"""Common fields for all jobs.
:param id: The ID of the job.
:param status: The status of the job.
:param created_at: The time the job was created.
:param completed_at: The time the job completed.
:param error: If status of the job is failed, this will contain the error message.
"""
id: str
status: JobStatus
created_at: datetime
completed_at: datetime | None = None
error: str | None = None

View file

@ -20,10 +20,9 @@ class Api(Enum):
agents = "agents"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"
tool_runtime = "tool_runtime"
evaluation = "evaluation"
telemetry = "telemetry"
@ -31,7 +30,6 @@ class Api(Enum):
shields = "shields"
vector_dbs = "vector_dbs"
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"

View file

@ -1,143 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optinally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import * # noqa: F401 F403
from .evaluation import * # noqa: F401 F403

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import CommonJobFields, JobType
from llama_stack.apis.datasets import DataSource
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model_id: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
agent_config: AgentConfig
EvaluationCandidate = register_schema(
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
name="EvaluationCandidate",
)
@json_schema_type
class EvaluationTask(BaseModel):
"""
A task for evaluation. To specify a task, one of the following must be provided:
- `benchmark_id`: Run evaluation task against a benchmark_id. Use this when you have a curated dataset and have settled on the graders.
- `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids. Use this when you have datasets and / or are iterating on your graders.
- `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids. Prefer this when you are early in your evaluation cycle and experimenting much more with your data and graders.
:param benchmark_id: The benchmark ID to evaluate.
:param dataset_id: The dataset ID to evaluate.
:param data_source: The data source to evaluate.
:param grader_ids: The grader IDs to evaluate.
"""
benchmark_id: Optional[str] = None
dataset_id: Optional[str] = None
data_source: Optional[DataSource] = None
grader_ids: Optional[List[str]] = None
@json_schema_type
class EvaluationJob(CommonJobFields):
type: Literal[JobType.evaluation.value] = JobType.evaluation.value
# input params for the submitted evaluation job
task: EvaluationTask
candidate: EvaluationCandidate
@json_schema_type
class EvaluationResponse(BaseModel):
"""
A response to an inline evaluation.
:param result_rows: The result data containing inputs, generations and grades in each row.
:param grades: Map of grader id to aggregated value.
"""
result_rows: List[Dict[str, Any]]
grades: Dict[str, Any]
class Evaluation(Protocol):
@webmethod(route="/evaluation/run", method="POST")
async def run(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationJob:
"""
Schedule a full evaluation job, by generating results using candidate and grading them.
:param task: The task to evaluate. To specify a task, one of the following must be provided:
- `benchmark_id`: Run evaluation task against a benchmark_id
- `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids
- `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:param candidate: The candidate to evaluate.
"""
...
@webmethod(route="/evaluation/run_sync", method="POST")
async def run_sync(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationResponse:
"""
Run an evaluation synchronously, i.e., without scheduling a job".
You should use this for quick testing, or when the number of rows is limited. Some implementations may have stricter restrictions on inputs which will be accepted.
:param task: The task to evaluate. To specify a task, one of the following must be provided:
- `benchmark_id`: Run evaluation task against a benchmark_id
- `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids
- `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:param candidate: The candidate to evaluate.
"""
...
@webmethod(route="/evaluation/grade", method="POST")
async def grade(self, task: EvaluationTask) -> EvaluationJob:
"""
Schedule a grading job, by grading generated (model or agent) results. The generated results are expected to be in the dataset.
:param task: The task to evaluate. To specify a task, one of the following must be provided:
- `benchmark_id`: Run evaluation task against a benchmark_id
- `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids
- `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:return: The evaluation job containing grader scores.
"""
...
@webmethod(route="/evaluation/grade_sync", method="POST")
async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
"""
Run grading synchronously on generated results, i.e., without scheduling a job.
You should use this for quick testing, or when the number of rows is limited. Some implementations may have stricter restrictions on inputs which will be accepted.
:param task: The task to evaluate. To specify a task, one of the following must be provided:
- `benchmark_id`: Run evaluation task against a benchmark_id
- `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids
- `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:return: The evaluation job containing grader scores. "generations" is not populated in the response.
"""
...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval import * # noqa: F401 F403
from .graders import * # noqa: F401 F403

View file

@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.resource import Resource
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class GraderType(Enum):
"""
A type of grader. Each type is a criteria for evaluating answers.
:cvar llm: Use an LLM to score the answer.
:cvar regex_parser: Use a regex parser to score the answer.
:cvar equality: Check if the answer is equal to the reference answer.
:cvar subset_of: Check if the answer is a subset of the reference answer.
:cvar factuality: Check if the answer is factually correct using LLM as judge.
:cvar faithfulness: Check if the answer is faithful to the reference answer using LLM as judge.
"""
llm = "llm"
regex_parser = "regex_parser"
equality = "equality"
subset_of = "subset_of"
factuality = "factuality"
faithfulness = "faithfulness"
@json_schema_type
class GraderTypeInfo(BaseModel):
"""
:param type: The type of grader.
:param description: A description of the grader type.
- E.g. Write your custom judge prompt to score the answer.
:param supported_dataset_purposes: The purposes that this grader can be used for.
"""
grader_type: GraderType
description: str
supported_dataset_purposes: List[DatasetPurpose] = Field(
description="The supported purposes (supported dataset schema) that this grader can be used for. E.g. eval/question-answer",
default_factory=list,
)
class LlmGraderParams(BaseModel):
model: str
prompt: str
score_regexes: List[str]
class RegexParserGraderParams(BaseModel):
parsing_regexes: List[str]
@json_schema_type
class LlmGrader(BaseModel):
type: Literal["llm"] = "llm"
llm: LlmGraderParams
@json_schema_type
class RegexParserGrader(BaseModel):
type: Literal["regex_parser"] = "regex_parser"
regex_parser: RegexParserGraderParams
@json_schema_type
class EqualityGrader(BaseModel):
type: Literal["equality"] = "equality"
@json_schema_type
class SubsetOfGrader(BaseModel):
type: Literal["subset_of"] = "subset_of"
@json_schema_type
class FactualityGrader(BaseModel):
type: Literal["factuality"] = "factuality"
@json_schema_type
class FaithfulnessGrader(BaseModel):
type: Literal["faithfulness"] = "faithfulness"
GraderDefinition = register_schema(
Annotated[
Union[
LlmGrader,
RegexParserGrader,
EqualityGrader,
SubsetOfGrader,
FactualityGrader,
FaithfulnessGrader,
],
Field(discriminator="type"),
],
name="GraderDefinition",
)
class CommonGraderFields(BaseModel):
grader: GraderDefinition
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
@json_schema_type
class Grader(CommonGraderFields, Resource):
type: Literal["grader"] = "grader"
@property
def grader_id(self) -> str:
return self.identifier
@property
def provider_grader_id(self) -> str:
return self.provider_resource_id
class GraderInput(CommonGraderFields, BaseModel):
grader_id: str
provider_id: Optional[str] = None
provider_grader_id: Optional[str] = None
class ListGradersResponse(BaseModel):
data: List[Grader]
class ListGraderTypesResponse(BaseModel):
data: List[GraderTypeInfo]
@runtime_checkable
class Graders(Protocol):
@webmethod(route="/graders", method="POST")
async def register_grader(
self,
grader: GraderDefinition,
grader_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Grader:
"""
Register a new grader.
:param grader: The grader definition, E.g.
- {
"type": "llm",
"llm": {
"model": "llama-405b",
"prompt": "You are a judge. Score the answer based on the question. {question} {answer}",
}
}
:param grader_id: (Optional) The ID of the grader. If not provided, a random ID will be generated.
:param metadata: (Optional) Any additional metadata for this grader.
- E.g. {
"description": "A grader that scores the answer based on the question.",
}
:return: The registered grader.
"""
...
@webmethod(route="/graders", method="GET")
async def list_graders(self) -> ListGradersResponse:
"""
List all graders.
:return: A list of graders.
"""
...
@webmethod(route="/graders/{grader_id:path}", method="GET")
async def get_grader(self, grader_id: str) -> Grader:
"""
Get a grader by ID.
:param grader_id: The ID of the grader.
:return: The grader.
"""
...
@webmethod(route="/graders/{grader_id:path}", method="DELETE")
async def unregister_grader(self, grader_id: str) -> None:
"""
Unregister a grader by ID.
:param grader_id: The ID of the grader.
"""
...
@webmethod(route="/graders/types", method="GET")
async def list_grader_types(self) -> ListGraderTypesResponse:
"""
List all grader types.
:return: A list of grader types and information about the types.
"""
...

View file

@ -14,6 +14,8 @@ class ResourceType(Enum):
shield = "shield"
vector_db = "vector_db"
dataset = "dataset"
grader = "grader"
# TODO: migrate scoring_function -> grader
scoring_function = "scoring_function"
benchmark = "benchmark"
tool = "tool"

View file

@ -1,78 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch", method="POST")
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score", method="POST")
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import * # noqa: F401 F403

View file

@ -1,148 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ...

View file

@ -11,13 +11,10 @@ from pydantic import BaseModel, Field
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
@ -125,10 +122,6 @@ class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
@ -146,7 +139,6 @@ RoutableObject = Union[
Shield,
VectorDB,
Dataset,
ScoringFn,
Benchmark,
Tool,
ToolGroup,
@ -159,7 +151,6 @@ RoutableObjectWithProvider = Annotated[
ShieldWithACL,
VectorDBWithACL,
DatasetWithACL,
ScoringFnWithACL,
BenchmarkWithACL,
ToolWithACL,
ToolGroupWithACL,
@ -172,8 +163,6 @@ RoutedProtocol = Union[
Safety,
VectorIO,
DatasetIO,
Scoring,
Eval,
ToolRuntime,
]
@ -301,7 +290,6 @@ a default SQLite store will be used.""",
shields: List[ShieldInput] = Field(default_factory=list)
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)

View file

@ -40,23 +40,19 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
AutoRoutedApiInfo(
routing_table_api=Api.benchmarks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
router_api=Api.evaluation,
),
]
def providable_apis() -> List[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
return [api for api in Api if api not in routing_table_apis and api not in [Api.inspect, Api.providers]]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:

View file

@ -11,7 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
@ -19,8 +19,6 @@ from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@ -46,7 +44,6 @@ from llama_stack.providers.datatypes import (
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
VectorDBsProtocolPrivate,
@ -73,13 +70,11 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.telemetry: Telemetry,
Api.datasetio: DatasetIO,
Api.datasets: Datasets,
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.benchmarks: Benchmarks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.evaluation: Evaluation,
Api.files: Files,
}
@ -91,12 +86,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
Api.evaluation: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
}
@ -119,7 +109,9 @@ async def resolve_impls(
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
routing_table_apis = {
x.routing_table_api for x in builtin_automatically_routed_apis()
}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = validate_and_prepare_providers(
@ -127,7 +119,9 @@ async def resolve_impls(
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
@ -137,7 +131,9 @@ async def resolve_impls(
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
def specs_for_autorouted_apis(
apis_to_serve: List[str] | Set[str],
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
@ -179,7 +175,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
routing_table_apis: Set[Api],
router_apis: Set[Api],
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
@ -187,17 +186,23 @@ def validate_and_prepare_providers(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
)
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
@ -207,10 +212,14 @@ def validate_and_prepare_providers(
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
def validate_provider(
provider: Provider, api: Api, provider_registry: ProviderRegistry
):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
@ -223,7 +232,8 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]],
run_config: StackRunConfig,
) -> List[Tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
@ -278,11 +288,15 @@ def sort_providers_by_deps(
async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
sorted_providers: List[Tuple[str, ProviderWithSpec]],
router_apis: Set[Api],
dist_registry: DistributionRegistry,
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
f"inner-{x.value}": {} for x in router_apis
}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -291,7 +305,9 @@ async def instantiate_providers(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
@ -349,7 +365,9 @@ async def instantiate_provider(
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
raise AttributeError(
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
)
module = importlib.import_module(provider_spec.module)
args = []
@ -386,7 +404,10 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@ -414,12 +435,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

View file

@ -14,7 +14,6 @@ from .routing_tables import (
BenchmarksRoutingTable,
DatasetsRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
VectorDBsRoutingTable,
@ -32,7 +31,6 @@ async def get_routing_table_impl(
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}
@ -48,10 +46,9 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
EvaluationRouter,
InferenceRouter,
SafetyRouter,
ScoringRouter,
ToolRuntimeRouter,
VectorIORouter,
)
@ -61,9 +58,8 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
"evaluation": EvaluationRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},

View file

@ -7,14 +7,21 @@
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.datasets import Dataset, DatasetPurpose, DataSource
from llama_stack.apis.evaluation import (
Evaluation,
EvaluationCandidate,
EvaluationJob,
EvaluationResponse,
EvaluationTask,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -36,12 +43,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
@ -481,11 +482,11 @@ class DatasetIORouter(DatasetIO):
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
) -> Dataset:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
return await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
@ -515,135 +516,6 @@ class DatasetIORouter(DatasetIO):
)
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
@ -709,3 +581,57 @@ class ToolRuntimeRouter(ToolRuntime):
) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
class EvaluationRouter(Evaluation):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvaluationRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvaluationRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvaluationRouter.shutdown")
pass
async def register_benchmark(
self,
dataset_id: str,
grader_ids: List[str],
benchmark_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Benchmark:
logger.debug(
f"EvaluationRouter.register_benchmark: {benchmark_id=} {dataset_id=} {grader_ids=} {metadata=}",
)
return await self.routing_table.register_benchmark(
benchmark_id=benchmark_id,
dataset_id=dataset_id,
grader_ids=grader_ids,
metadata=metadata,
)
async def run(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationJob:
raise NotImplementedError("Run is not implemented yet")
async def run_sync(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationResponse:
raise NotImplementedError("Run sync is not implemented yet")
async def grade(self, task: EvaluationTask) -> EvaluationJob:
raise NotImplementedError("Grade is not implemented yet")
async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
raise NotImplementedError("Grade sync is not implemented yet")

View file

@ -12,7 +12,6 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
@ -25,12 +24,6 @@ from llama_stack.apis.datasets import (
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
@ -50,7 +43,6 @@ from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
@ -81,10 +73,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
@ -130,7 +118,7 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
for _pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
@ -140,12 +128,6 @@ class CommonRoutingTableImpl(RoutingTable):
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
@ -163,8 +145,6 @@ class CommonRoutingTableImpl(RoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
@ -457,46 +437,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
@ -507,35 +447,38 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def unregister_benchmark(self, benchmark_id: str) -> None:
benchmark = await self.get_benchmark(benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark {benchmark_id} not found")
await self.unregister_object(benchmark)
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
grader_ids: List[str],
benchmark_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None:
) -> Benchmark:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
# TODO (xiyan): we will need a way to infer provider_id for evaluation
# keep it as meta-reference for now
if len(self.impls_by_provider_id) == 0:
raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
grader_ids=grader_ids,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
provider_resource_id=benchmark_id,
)
await self.register_object(benchmark)
return benchmark
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):

View file

@ -17,16 +17,15 @@ from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files
from llama_stack.apis.graders import Graders
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
@ -56,10 +55,7 @@ class LlamaStack(
Telemetry,
PostTraining,
VectorIO,
Eval,
Benchmarks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
@ -68,6 +64,8 @@ class LlamaStack(
ToolRuntime,
RAGToolRuntime,
Files,
Graders,
Evaluation,
):
pass
@ -77,12 +75,6 @@ RESOURCES = [
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]

View file

@ -26,7 +26,10 @@ class LlamaStackApi:
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
# TODO(xiyan): fix this
# return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
raise NotImplementedError("Scoring is not implemented")
llama_stack_api = LlamaStackApi()

View file

@ -9,7 +9,6 @@ from streamlit_option_menu import option_menu
from llama_stack.distribution.ui.page.distribution.datasets import datasets
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.distribution.ui.page.distribution.models import models
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.distribution.ui.page.distribution.shields import shields
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
@ -43,8 +42,9 @@ def resources_page():
datasets()
elif selected_resource == "Models":
models()
elif selected_resource == "Scoring Functions":
scoring_functions()
# TODO(xiyan): fix this
# elif selected_resource == "Scoring Functions":
# scoring_functions()
elif selected_resource == "Shields":
shields()

View file

@ -13,7 +13,6 @@ from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
from llama_stack.apis.vector_dbs import VectorDB
@ -42,12 +41,6 @@ class DatasetsProtocolPrivate(Protocol):
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ...

View file

@ -1,234 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List
from tqdm import tqdm
from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
class MetaReferenceEvalImpl(
Eval,
BenchmarksProtocolPrivate,
):
def __init__(
self,
config: MetaReferenceEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.benchmarks = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)
self.benchmarks[benchmark.identifier] = benchmark
async def shutdown(self) -> None: ...
async def register_benchmark(self, task_def: Benchmark) -> None:
# Store in kvstore
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set(
key=key,
value=task_def.model_dump_json(),
)
self.benchmarks[task_def.identifier] = task_def
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
else:
raise ValueError("Invalid input row")
return generations
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
]
if benchmark_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return Job(job_id=job_id, status=JobStatus.completed)
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id]

View file

@ -7,20 +7,19 @@ from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
from .config import MetaReferenceEvaluationConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
config: MetaReferenceEvaluationConfig,
deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl
from .evaluation import MetaReferenceEvaluationImpl
impl = MetaReferenceEvalImpl(
impl = MetaReferenceEvaluationImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)

View file

@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore.config import (
)
class MetaReferenceEvalConfig(BaseModel):
class MetaReferenceEvaluationConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
@ -21,6 +21,6 @@ class MetaReferenceEvalConfig(BaseModel):
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="meta_reference_eval.db",
db_name="meta_reference_evaluation.db",
)
}

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from .....apis.benchmarks import Benchmark
from .....apis.evaluation.evaluation import (
Evaluation,
EvaluationCandidate,
EvaluationJob,
EvaluationResponse,
EvaluationTask,
)
from .config import MetaReferenceEvaluationConfig
EVAL_TASKS_PREFIX = "benchmarks:"
class MetaReferenceEvaluationImpl(
Evaluation,
BenchmarksProtocolPrivate,
):
def __init__(
self,
config: MetaReferenceEvaluationConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.agents_api = agents_api
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_benchmark(self, benchmark: Benchmark) -> None:
pass
async def run(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationJob:
raise NotImplementedError("Run is not implemented yet")
async def run_sync(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationResponse:
raise NotImplementedError("Run sync is not implemented yet")
async def grade(self, task: EvaluationTask) -> EvaluationJob:
raise NotImplementedError("Grade is not implemented yet")
async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
raise NotImplementedError("Grade sync is not implemented yet")

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, Any],
):
from .scoring import BasicScoringImpl
impl = BasicScoringImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
await impl.initialize()
return impl

View file

@ -1,14 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class BasicScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -1,128 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [
EqualityScoringFn,
SubsetOfScoringFn,
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]
class BasicScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
config: BasicScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for fn in FIXED_FNS:
impl = fn()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
res = await self.score(
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output
from .fn_defs.bfcl import bfcl
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
contain_func_call = False
error = None
error_type = None
checker_result = {}
try:
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
contain_func_call = True
# if not is_function_calling_format_output(prediction):
if is_empty_output(prediction):
contain_func_call = False
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
error_type = "ast_decoder:decoder_wrong_output_format"
else:
checker_result = ast_checker(
json.loads(x["function"]),
prediction,
json.loads(x["ground_truth"]),
x["language"],
test_category=test_category,
model_name="",
)
except Exception as e:
prediction = ""
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
error_type = "ast_decoder:decoder_failed"
return {
"prediction": prediction,
"contain_func_call": contain_func_call,
"valid": checker_result.get("valid", False),
"error": error or checker_result.get("error", ""),
"error_type": error_type or checker_result.get("error_type", ""),
}
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
return {"valid": x["valid"]}
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
# If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
return {"valid": float(acc)}
class BFCLScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for BFCL
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
bfcl.identifier: bfcl,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "bfcl",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category)
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
score = gen_relevance_acc(score_result)["valid"]
else:
score = gen_valid(score_result)["valid"]
return {
"score": float(score),
}

View file

@ -1,240 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s, re.UNICODE)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
docvqa basically matches the generated answer against several allowed
choices, but we need to normalize the answer to avoid penalizing
trivial differences
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}

View file

@ -1,41 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality
class EqualityScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
equality.identifier: equality,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert "generated_answer" in input_row, "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer == generated_answer else 0.0
return {
"score": score,
}

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
bfcl = ScoringFn(
identifier="basic::bfcl",
description="BFCL complex scoring",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="bfcl",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
regex_parser_math_response = ScoringFn(
identifier="basic::regex_parser_math_response",
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-math-response",
params=RegexParserScoringFnParams(
parsing_regexes=MATH_ANSWER_REGEXES,
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -1,71 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [
r"The best answer is ",
r"Answer\s*:",
r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",
r"उत्तर\s*:",
r"উত্তরঃ",
r"উত্তর\s*:",
r"Antwort\s*:",
r"답변\s*:",
r"정답\s*:",
r"\s*:",
r"答案\s*",
r"答案\s*:",
r"\s*",
r"\s*:",
r"答复\s*",
r"答曰\s*",
r"الإجابة:",
r"الجواب:",
r"إجابة:",
r"الإجابة النهائية:",
r"الإجابة الصحيحة:",
r"الإجابة الصحيحة هي:",
r"الإجابة هي:",
r"Respuesta\s*:",
r"Risposta\s*:",
r"答え\s*:",
r"答え\s*",
r"回答\s*:",
r"回答\s*",
r"解答\s*:",
r"Jawaban\s*:",
r"Réponse\s*:",
r"Resposta\s*:",
r"Jibu\s*:",
r"Idahun\s*:",
r"Ìdáhùn\s*:",
r"Idáhùn\s*:",
r"Àmọ̀nà\s*:",
r"Àdáhùn\s*:",
r"Ànúgọ\s*:",
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn(
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,80 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}

View file

@ -1,66 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import (
regex_parser_math_response,
)
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_math_response.identifier: regex_parser_math_response,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
parsing_regexes = fn_def.params.parsing_regexes
assert len(parsing_regexes) == 1, (
"Only one parsing regex is supported for regex_parser_math_response scoring function."
)
parsing_regexes = fn_def.params.parsing_regexes[0]
normalized_generated_answer = normalize_final_answer(
first_answer(generated_answer),
parsing_regexes,
match_first=True,
)
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
return {
"score": score,
}

View file

@ -1,58 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
# parse answer according to regex
parsed_answer = None
for regex in fn_def.params.parsing_regexes:
match = re.search(regex, generated_answer)
if match:
parsed_answer = match.group(1)
break
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
return {
"score": score,
}

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
subset_of.identifier: subset_of,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer in generated_answer else 0.0
return {
"score": score,
}

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,296 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ast
from .tree_sitter import get_parser
def parse_java_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("java")
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error parsing java the source code.")
def get_text(node):
"""Returns the text represented by the node."""
return source_code[node.start_byte : node.end_byte]
def traverse_node(node, nested=False):
if node.type == "string_literal":
if nested:
return get_text(node)
# Strip surrounding quotes from string literals
return get_text(node)[1:-1]
elif node.type == "character_literal":
if nested:
return get_text(node)
# Strip surrounding single quotes from character literals
return get_text(node)[1:-1]
"""Traverse the node to collect texts for complex structures."""
if node.type in [
"identifier",
"class_literal",
"type_identifier",
"method_invocation",
]:
return get_text(node)
elif node.type == "array_creation_expression":
# Handle array creation expression specifically
type_node = node.child_by_field_name("type")
value_node = node.child_by_field_name("value")
type_text = traverse_node(type_node, True)
value_text = traverse_node(value_node, True)
return f"new {type_text}[]{value_text}"
elif node.type == "object_creation_expression":
# Handle object creation expression specifically
type_node = node.child_by_field_name("type")
arguments_node = node.child_by_field_name("arguments")
type_text = traverse_node(type_node, True)
if arguments_node:
# Process each argument carefully, avoiding unnecessary punctuation
argument_texts = []
for child in arguments_node.children:
if child.type not in [
",",
"(",
")",
]: # Exclude commas and parentheses
argument_text = traverse_node(child, True)
argument_texts.append(argument_text)
arguments_text = ", ".join(argument_texts)
return f"new {type_text}({arguments_text})"
else:
return f"new {type_text}()"
elif node.type == "set":
# Handling sets specifically
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
return "{" + ", ".join(items) + "}"
elif node.child_count > 0:
return "".join(traverse_node(child, True) for child in node.children)
else:
return get_text(node)
def extract_arguments(args_node):
arguments = {}
for child in args_node.children:
if child.type == "assignment_expression":
# For named parameters
name_node, value_node = child.children[0], child.children[2]
name = get_text(name_node)
value = traverse_node(value_node)
if name in arguments:
if not isinstance(arguments[name], list):
arguments[name] = [arguments[name]]
arguments[name].append(value)
else:
arguments[name] = value
# arguments.append({'name': name, 'value': value})
elif child.type in ["identifier", "class_literal", "set"]:
# For unnamed parameters and handling sets
value = traverse_node(child)
if None in arguments:
if not isinstance(arguments[None], list):
arguments[None] = [arguments[None]]
arguments[None].append(value)
else:
arguments[None] = value
return arguments
def traverse(node):
if node.type == "method_invocation":
# Extract the function name and its arguments
method_name = get_text(node.child_by_field_name("name"))
class_name_node = node.child_by_field_name("object")
if class_name_node:
class_name = get_text(class_name_node)
function_name = f"{class_name}.{method_name}"
else:
function_name = method_name
arguments_node = node.child_by_field_name("arguments")
if arguments_node:
arguments = extract_arguments(arguments_node)
for key, value in arguments.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
return [{function_name: arguments}]
else:
for child in node.children:
result = traverse(child)
if result:
return result
result = traverse(root_node)
return result if result else {}
def parse_javascript_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("javascript")
# Parse the source code
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error js parsing the source code.")
# Function to recursively extract argument details
def extract_arguments(node):
args = {}
for child in node.children:
if child.type == "assignment_expression":
# Extract left (name) and right (value) parts of the assignment
name = child.children[0].text.decode("utf-8")
value = child.children[2].text.decode("utf-8")
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
value = value[1:-1] # Trim the quotation marks
if name in args:
if not isinstance(args[name], list):
args[name] = [args[name]]
args[name].append(value)
else:
args[name] = value
elif child.type == "identifier" or child.type == "true":
# Handle non-named arguments and boolean values
value = child.text.decode("utf-8")
if None in args:
if not isinstance(args[None], list):
args[None] = [args[None]]
args[None].append(value)
else:
args[None] = value
return args
# Find the function call and extract its name and arguments
if root_node.type == "program":
for child in root_node.children:
if child.type == "expression_statement":
for sub_child in child.children:
if sub_child.type == "call_expression":
function_name = sub_child.children[0].text.decode("utf8")
arguments_node = sub_child.children[1]
parameters = extract_arguments(arguments_node)
for key, value in parameters.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
result = [{function_name: parameters}]
return result
def ast_parse(input_str, language="Python"):
if language == "Python":
cleaned_input = input_str.strip("[]'")
parsed = ast.parse(cleaned_input, mode="eval")
extracted = []
if isinstance(parsed.body, ast.Call):
extracted.append(resolve_ast_call(parsed.body))
else:
for elem in parsed.body.elts:
extracted.append(resolve_ast_call(elem))
return extracted
elif language == "Java":
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
elif language == "JavaScript":
return parse_javascript_function_call(input_str[1:-1])
else:
raise NotImplementedError(f"Unsupported language: {language}")
def resolve_ast_call(elem):
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
# Parse when args are simply passed as an unnamed dictionary arg
for arg in elem.args:
if isinstance(arg, ast.Dict):
for key, value in zip(arg.keys, arg.values):
if isinstance(key, ast.Constant):
arg_name = key.value
output = resolve_ast_by_type(value)
args_dict[arg_name] = output
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
if value.value is Ellipsis:
output = "..."
else:
output = value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value))
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value)
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = eval(ast.unparse(value.body[0].value))
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value)
except:
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
def decode_ast(result, language="Python"):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decoded_output = ast_parse(func, language)
return decoded_output
def decode_execute(result):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decode_output = ast_parse(func)
execution_list = []
for function_call in decode_output:
for key, value in function_call.items():
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
return execution_list

View file

@ -1,989 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import time
from typing import Any
# Comment out for now until we actually use the rest checker in evals
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
class NoAPIKeyError(Exception):
def __init__(self):
self.message = "Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
super().__init__(self.message)
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
JAVA_TYPE_CONVERSION = {
"byte": int,
"short": int,
"integer": int,
"float": float,
"double": float,
"long": int,
"boolean": bool,
"char": str,
"Array": list,
"ArrayList": list,
"Set": set,
"HashMap": dict,
"Hashtable": dict,
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
"Stack": list,
"String": str,
"any": str,
}
JS_TYPE_CONVERSION = {
"String": str,
"integer": int,
"float": float,
"Bigint": int,
"Boolean": bool,
"dict": dict,
"array": list,
"any": str,
}
# We switch to conditional import for the following two imports to avoid unnecessary installations.
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
# from js_type_converter import js_type_converter
# from java_type_converter import java_type_converter
PYTHON_TYPE_MAPPING = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"array": list,
"tuple": list,
"dict": dict,
"any": str,
}
# This is the list of types that we need to recursively check its values
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
#### Helper functions for AST ####
def find_description(func_descriptions, name):
if type(func_descriptions) == list:
for func_description in func_descriptions:
if func_description["name"] == name:
return func_description
return None
else:
# it is a dict, there is only one function
return func_descriptions
def get_possible_answer_type(possible_answer: list):
for answer in possible_answer:
if answer != "": # Optional parameter
return type(answer)
return None
def type_checker(
param: str,
value,
possible_answer: list,
expected_type_description: str,
expected_type_converted,
nested_type_converted,
):
# NOTE: This type checker only supports nested type checking for one level deep.
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
result: Any = {
"valid": True,
"error": [],
"is_variable": False,
"error_type": "type_error:simple",
}
is_variable = False
# check for the case where a variable is used instead of a actual value.
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if possible_answer_type != expected_type_converted:
is_variable = True
# value is the same type as in function description
if type(value) == expected_type_converted:
# We don't need to do recursive check for simple types
if nested_type_converted == None:
result["is_variable"] = is_variable
return result
else:
for possible_answer_item in possible_answer:
flag = True # Each parameter should match to at least one possible answer type.
# Here, we assume that each item should be the same type. We could also relax it.
if type(possible_answer_item) == list:
for value_item in value:
checker_result = type_checker(
param,
value_item,
possible_answer_item,
str(nested_type_converted),
nested_type_converted,
None,
)
if not checker_result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": [], "is_variable": is_variable}
result["valid"] = False
result["error"] = [
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
]
result["error_type"] = "type_error:nested"
# value is not as expected, check for the case where a variable is used instead of a actual value
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if type(value) == possible_answer_type:
result["is_variable"] = True
return result
result["valid"] = False
result["error"].append(
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:simple"
return result
def standardize_string(input_string: str):
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
# It will also convert all the single quotes to double quotes
# This is used to compare the model output with the possible answers
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
regex_string = r"[ \,\.\/\-\_\*\^]"
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
def string_checker(param: str, model_output: str, possible_answer: list):
standardize_possible_answer = []
standardize_model_output = standardize_string(model_output)
for i in range(len(possible_answer)):
if type(possible_answer[i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[i]))
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
],
"error_type": "value_error:string",
}
return {"valid": True, "error": []}
def list_checker(param: str, model_output: list, possible_answer: list):
# Convert the tuple to a list
standardize_model_output = list(model_output)
# If the element in the list is a string, we need to standardize it
for i in range(len(standardize_model_output)):
if type(standardize_model_output[i]) == str:
standardize_model_output[i] = standardize_string(model_output[i])
standardize_possible_answer: Any = []
# We also need to standardize the possible answers
for i in range(len(possible_answer)):
standardize_possible_answer.append([])
for j in range(len(possible_answer[i])):
if type(possible_answer[i][j]) == str:
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
else:
standardize_possible_answer[i].append(possible_answer[i][j])
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
],
"error_type": "value_error:list/tuple",
}
return {"valid": True, "error": []}
def dict_checker(param: str, model_output: dict, possible_answers: list):
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
# The current dataset only contains simple dictionaries, so this is sufficient.
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
for i in range(len(possible_answers)):
if possible_answers[i] == "":
continue
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
flag = True
possible_answer = possible_answers[i]
# possible_anwer is a single dictionary
for key, value in model_output.items():
if key not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
standardize_value = value
# If the value is a string, we need to standardize it
if type(value) == str:
standardize_value = standardize_string(value)
# We also need to standardize the possible answers if they are string
standardize_possible_answer = []
for i in range(len(possible_answer[key])):
if type(possible_answer[key][i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
else:
standardize_possible_answer.append(possible_answer[key][i])
if standardize_value not in standardize_possible_answer:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
)
result["error_type"] = "value_error:dict_value"
flag = False
break
for key, value in possible_answer.items():
if key not in model_output and "" not in value:
result["valid"] = False
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def list_dict_checker(param: str, model_output: list, possible_answers: list):
# This function takes in a list of dictionaries and checks if each dictionary is valid
# The order of the dictionaries in the list must match the order of the possible answers
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
for answer_index in range(len(possible_answers)):
flag = True # True means so far, all dictionaries are valid
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
if len(model_output) != len(possible_answers[answer_index]):
result["valid"] = False
result["error"] = ["Wrong number of dictionaries in the list."]
result["error_type"] = "value_error:list_dict_count"
flag = False
continue
for dict_index in range(len(model_output)):
result = dict_checker(
param,
model_output[dict_index],
[possible_answers[answer_index][dict_index]],
)
if not result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def simple_function_checker(
func_description: dict,
model_output: dict,
possible_answer: dict,
language: str,
model_name: str,
):
possible_answer = list(possible_answer.values())[0]
# Extract function name and parameters details
func_name = func_description["name"]
param_details = func_description["parameters"]["properties"]
required_params = func_description["parameters"]["required"]
# Initialize a result dictionary
result = {
"valid": True,
"error": [],
"error_type": "simple_function_checker:unclear",
}
# Check if function name matches
if func_name not in model_output:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Function name {repr(func_name)} not found in model output."
)
result["error_type"] = "simple_function_checker:wrong_func_name"
return result
model_params = model_output[func_name]
# Check for required parameters in model output
for param in required_params:
if param not in model_params:
result["valid"] = False
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:missing_required"
return result
# Validate types and values for each parameter in model output
for param, value in model_params.items():
if param not in param_details or param not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:unexpected_param"
return result
full_param_details = param_details[param]
expected_type_description = full_param_details["type"] # This is a string
is_variable = False
nested_type_converted = None
if language == "Java":
from evals.utils.bfcl.java_type_converter import java_type_converter
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JAVA_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:java"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
value = java_type_converter(value, expected_type_description, nested_type)
else:
value = java_type_converter(value, expected_type_description)
elif language == "JavaScript":
from evals.utils.bfcl.js_type_converter import js_type_converter
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JS_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:js"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
value = js_type_converter(value, expected_type_description, nested_type)
else:
value = js_type_converter(value, expected_type_description)
elif language == "Python":
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
# We convert all tuple value to list when the expected type is tuple.
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
if expected_type_description == "tuple" and type(value) == tuple:
value = list(value)
# Allow python auto conversion from int to float
if language == "Python" and expected_type_description == "float" and type(value) == int:
value = float(value)
# Type checking
# In fact, we only check for Python here.
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
type_check_result = type_checker(
param,
value,
possible_answer[param],
expected_type_description,
expected_type_converted,
nested_type_converted,
)
is_variable = type_check_result["is_variable"]
if not type_check_result["valid"]:
return type_check_result
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
# We can just treat the variable as a string and use the normal flow.
if not is_variable:
# Special handle for dictionaries
if expected_type_converted == dict:
result = dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for list of dictionaries
elif expected_type_converted == list and nested_type_converted == dict:
result = list_dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for strings
elif expected_type_converted == str:
# We don't check for case sensitivity for string, as long as it's not a variable
result = string_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
elif expected_type_converted == list:
result = list_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Check if the value is within the possible answers
if value not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
)
result["error_type"] = "value_error:others"
return result
# Check for optional parameters not provided but allowed
for param in possible_answer:
if param not in model_params and "" not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Optional parameter {repr(param)} not provided and not marked as optional."
)
result["error_type"] = "simple_function_checker:missing_optional"
return result
return result
def parallel_function_checker_enforce_order(
func_descriptions: list,
model_output: list,
possible_answers: dict,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_enforce_order:wrong_count",
}
func_name_list = list(possible_answers.keys())
possible_answers_list = []
for key, value in possible_answers.items():
possible_answers_list.append({key: value})
for i in range(len(possible_answers_list)):
func_description = find_description(func_descriptions, func_name_list[i])
result = simple_function_checker(
func_description,
model_output[i],
possible_answers_list[i],
language,
model_name,
)
if not result["valid"]:
return result
return {"valid": True, "error": []}
def parallel_function_checker_no_order(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_no_order:wrong_count",
}
matched_indices = []
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
# It must be this way because we need ground truth to fetch the correct function description
for i in range(len(possible_answers)):
# possible_answers[i] is a dictionary with only one key
func_name_expected = list(possible_answers[i].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
all_errors = []
for index in range(len(model_output)):
if index in matched_indices:
continue
result = simple_function_checker(
func_description,
model_output[index],
possible_answers[i],
language,
model_name,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_output_item": model_output[index],
"possible_answer_item": possible_answers[i],
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "parallel_function_checker_no_order:cannot_find_match",
}
return {"valid": True, "error": []}
def multiple_function_checker(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "multiple_function_checker:wrong_count",
}
# possible_answers is a list of only one dictionary with only one key
func_name_expected = list(possible_answers[0].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
return simple_function_checker(
func_description,
model_output[0],
possible_answers[0],
language,
model_name,
)
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
if type(exec_output) != type(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type",
"model_executed_output": exec_output,
}
if type(exec_output) == dict:
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
# This happens when the key is a timestamp or a random number.
if is_sanity_check:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:dict_length",
"model_executed_output": exec_output,
}
else:
return result
for key, value in expected_result.items():
if key not in exec_output:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
"model_executed_output": exec_output,
}
for key, value in exec_output.items():
if key not in expected_result:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
"model_executed_output": exec_output,
}
if type(exec_output) == list:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:list_length",
"model_executed_output": exec_output,
}
return result
#### Helper functions for Exec ####
def executable_checker_simple(
function_call: str,
expected_result,
expected_result_type: str,
is_sanity_check=False,
):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
exec_dict: Any = {}
try:
exec(
"from executable_python_function import *" + "\nresult=" + function_call,
exec_dict,
)
exec_output = exec_dict["result"]
except NoAPIKeyError as e:
raise e
except Exception as e:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
)
result["error_type"] = "executable_checker:execution_error"
return result
# We need to special handle the case where the execution result is a tuple and convert it to a list
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
if isinstance(exec_output, tuple):
exec_output = list(exec_output)
if expected_result_type == "exact_match":
if exec_output != expected_result:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
)
result["error_type"] = "executable_checker:wrong_result"
result["model_executed_output"] = exec_output
return result
elif expected_result_type == "real_time_match":
# Allow for 5% difference
if (type(expected_result) == float or type(expected_result) == int) and (
type(exec_output) == float or type(exec_output) == int
):
if not (
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
<= exec_output
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
):
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
# structural match
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
if not pattern_match_result["valid"]:
return pattern_match_result
return result
def executable_checker_parallel_no_order(
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
):
if len(decoded_result) != len(expected_exec_result):
return {
"valid": False,
"error": [
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
],
"error_type": "value_error:exec_result_count",
}
matched_indices = []
for i in range(len(expected_exec_result)):
all_errors = []
for index in range(len(decoded_result)):
if index in matched_indices:
continue
result = executable_checker_simple(
decoded_result[index],
expected_exec_result[i],
expected_exec_result_type[i],
False,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_executed_output": (
result["model_executed_output"] if "model_executed_output" in result else None
),
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "executable_checker:cannot_find_match",
}
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
#### Main function ####
def executable_checker_rest(func_call, idx):
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
EVAL_GROUND_TRUTH = f.readlines()
if "https://geocode.maps.co" in func_call:
time.sleep(2)
if "requests_get" in func_call:
func_call = func_call.replace("requests_get", "requests.get")
try:
response = eval(func_call)
except Exception as e:
return {
"valid": False,
"error": [f"Execution failed. {str(e)}"],
"error_type": "executable_checker_rest:execution_error",
}
try:
if response.status_code == 200:
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
try:
if isinstance(eval_GT_json, dict):
if isinstance(response.json(), dict):
if set(eval_GT_json.keys()) == set(response.json().keys()):
return {"valid": True, "error": [], "error_type": ""}
return {
"valid": False,
"error": ["Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {
"valid": False,
"error": [f"Expected dictionary, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
elif isinstance(eval_GT_json, list):
if isinstance(response.json(), list):
if len(eval_GT_json) != len(response.json()):
return {
"valid": False,
"error": [f"Response list length inconsistency."],
"error_type": "value_error:exec_result_rest_count",
}
else:
for i in range(len(eval_GT_json)):
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
return {
"valid": False,
"error": [f"Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {"valid": True, "error": []}
else:
return {
"valid": False,
"error": [f"Expected list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
return {
"valid": False,
"error": [f"Expected dict or list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
except Exception as e:
return {
"valid": False,
"error": [
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
],
"error_type": "executable_checker_rest:response_format_error",
}
else:
return {
"valid": False,
"error": [f"Execution result status code is not 200, got {response.status_code}"],
"error_type": "executable_checker_rest:wrong_status_code",
}
except Exception as e:
return {
"valid": False,
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
"error_type": "executable_checker_rest:cannot_get_status_code",
}
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
if "parallel" in test_category:
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
elif "multiple" in test_category:
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
else:
if len(model_output) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_function_checker:wrong_count",
}
return simple_function_checker(
func_description[0],
model_output[0],
possible_answer[0],
language,
model_name,
)
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
if "multiple" in test_category or "parallel" in test_category:
return executable_checker_parallel_no_order(
decoded_result,
func_description["execution_result"],
func_description["execution_result_type"],
)
else:
if len(decoded_result) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_exec_checker:wrong_count",
}
return executable_checker_simple(
decoded_result[0],
func_description["execution_result"][0],
func_description["execution_result_type"][0],
False,
)
def is_empty_output(decoded_output):
# This function is a patch to the ast decoder for relevance detection
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
if not is_function_calling_format_output(decoded_output):
return True
if len(decoded_output) == 0:
return True
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
return True
def is_function_calling_format_output(decoded_output):
# Ensure the output is a list of dictionaries
if type(decoded_output) == list:
for item in decoded_output:
if type(item) != dict:
return False
return True
return False

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
import it from here so that we can centrally manage things as necessary.
"""
# These currently work with tree-sitter 0.23.0
# NOTE: Don't import tree-sitter or any of the language modules in the main module
# because not all environments have them. Import lazily inside functions where needed.
import importlib
import typing
if typing.TYPE_CHECKING:
import tree_sitter
def get_language(language: str) -> "tree_sitter.Language":
import tree_sitter
language_module_name = f"tree_sitter_{language}"
try:
language_module = importlib.import_module(language_module_name)
except ModuleNotFoundError as exc:
raise ValueError(
f"Language {language} is not found. Please install the tree-sitter-{language} package."
) from exc
return tree_sitter.Language(language_module.language())
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
import tree_sitter
lang = get_language(language)
return tree_sitter.Parser(lang, **kwargs)

View file

@ -1,330 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Sequence
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
# from minerva
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
if isinstance(expression, float):
return expression
new_expression = f"{expression}"
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
for match in re.finditer(regex, expression):
try:
value = float(match.group(1)) / float(match.group(2))
new_expression = new_expression.replace(
match.group(),
f"{{value:{fmt}}}".format(value=value),
1,
)
except Exception:
continue
return new_expression
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
try:
with time_limit(seconds=5):
from sympy.parsing.latex import parse_latex
value = parse_latex(expression).evalf() # type: ignore
return f"{{value:{fmt}}}".format(value=value)
except Exception:
return expression
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
for marker in markers:
text = text.split(marker)[0]
return text
def extract_result_from_boxed(answer: str) -> str:
box_start = "\\boxed"
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
start = answer.rfind(box_start)
if start < 0:
return ""
answer = answer[start + len(box_start) :].strip()
ends_with_curly = answer.startswith("{")
i = 0
open_braces = 0
while i < len(answer):
if answer[i] == "{":
open_braces += 1
elif answer[i] == "}":
open_braces -= 1
if open_braces == 0:
if ends_with_curly:
answer = answer[: i + 1].strip()
break
elif answer[i] == "$":
answer = answer[:i].strip()
break
i += 1
else:
return ""
# remove extra curly braces
while True:
if answer.startswith("{") and answer.endswith("}"):
answer = answer[1:-1].strip()
else:
break
return answer
# from minerva paper + _normalise_result from xavierm
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
"""Extract and normalize a final answer to a quantitative reasoning question."""
match = re.findall(regex_pattern, final_answer)
extraction: str
if len(match) > 0:
if match_first:
extraction = match[0]
else:
extraction = match[-1]
else:
extraction = extract_result_from_boxed(final_answer)
if len(extraction) == 0:
return final_answer
else:
final_answer = extraction
final_answer = final_answer.split("=")[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
# Normalize 100,000 -> 100000
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
# If the final answer is a single letter in parentheses, remove the parentheses
# Example: (a) -> a (but not (ab) -> ab)
if re.match(r"\([a-zA-Z]\)", final_answer):
final_answer = final_answer[1]
return _normalise_result(final_answer)
def _normalise_result(string: str) -> str:
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("cfrac", "frac")
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\le", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace(r"\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
string = string.split("=")[-1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def _remove_right_units(string: str) -> str:
# "\\text{ " only ever occurs (at least in the val set) when describing units
try:
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
except AssertionError:
return string
def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if len(split) == 0:
return string
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _fix_fracs(string: str) -> str:
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) == 0:
return string
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string: str) -> str:
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
ia = int(a)
ib = int(b)
assert string == "{}/{}".format(ia, ib)
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
return new_string
except (ValueError, AssertionError):
return string

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig
class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl
impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@ -1,232 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, List, Optional
from autoevals.llm import Factuality
from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextEntityRecall,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
from pydantic import BaseModel
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
class BraintrustScoringFnEntry(BaseModel):
identifier: str
evaluator: Any
fn_def: ScoringFn
SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
BraintrustScoringFnEntry(
identifier="braintrust::factuality",
evaluator=Factuality(),
fn_def=factuality_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-correctness",
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-entity-recall",
evaluator=ContextEntityRecall(),
fn_def=context_entity_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-precision",
evaluator=ContextPrecision(),
fn_def=context_precision_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-recall",
evaluator=ContextRecall(),
fn_def=context_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-relevancy",
evaluator=ContextRelevancy(),
fn_def=context_relevancy_fn_def,
),
]
class BraintrustScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
):
def __init__(
self,
config: BraintrustScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.braintrust_evaluators = {
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
)
return scoring_fn_defs_list
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score
return {"score": score, "metadata": result.metadata}
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
}

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description=(
"Scores the correctness of the answer based on the ground truth. "
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Test output relevancy against the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_similarity_fn_def = ScoringFn(
identifier="braintrust::answer-similarity",
description=(
"Test output similarity against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_entity_recall_fn_def = ScoringFn(
identifier="braintrust::context-entity-recall",
description=(
"Evaluates how well the context captures the named entities present in the "
"reference answer. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_precision_fn_def = ScoringFn(
identifier="braintrust::context-precision",
description=(
"Measures how much of the provided context is actually relevant to answering the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_recall_fn_def = ScoringFn(
identifier="braintrust::context-recall",
description=(
"Evaluates how well the context covers the information needed to answer the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description=(
"Test output factuality against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="factuality",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
faithfulness_fn_def = ScoringFn(
identifier="braintrust::faithfulness",
description=(
"Test output faithfulness to the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl
impl = LlmAsJudgeScoringImpl(config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference])
await impl.initialize()
return impl

View file

@ -1,14 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class LlmAsJudgeScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -1,110 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FN = LlmAsJudgeScoringFn
class LlmAsJudgeScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
config: LlmAsJudgeScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
async def initialize(self) -> None:
impl = LLM_JUDGE_FN(inference_api=self.inference_api)
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
assert f.identifier.startswith("llm-as-judge"), (
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
)
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFn) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
res = await self.score(
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
scoring_fn = self.llm_as_judge_fn
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,96 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
LLMAsJudgeScoringFnParams,
ScoringFn,
)
GRADER_TEMPLATE = """
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
First, I will give examples of each grade, and then you will grade a new example.
The following are examples of CORRECT predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia Obama and Sasha Obama
Predicted answer 1: sasha and malia obama
Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check
Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001.
```
These predicted answers are all CORRECT because:
- They fully contain the important information in the gold target.
- They do not contain any information that contradicts the gold target.
- Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
- Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions.
The following are examples of INCORRECT predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia and Sasha
Predicted answer 1: Malia.
Predicted answer 2: Malia, Sasha, and Susan.
Predicted answer 3: Barack Obama does not have any children.
Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia.
Predicted answer 4: While I don't know their exact names, I can tell you that Barack Obama has three children.
Predicted answer 5: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer?
Predicted answer 6: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information.
```
These predicted answers are all INCORRECT because:
- A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect.
The following are examples of NOT_ATTEMPTED predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia and Sasha
Predicted answer 1: I don't know.
Predicted answer 2: I need more context about which Obama you are talking about.
Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children.
Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one.
```
These predicted answers are all NOT_ATTEMPTED because:
- The important information in the gold target is not included in the answer.
- No statements in the answer contradict the gold target.
Also note the following things:
- For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k".
- Predicted answers "120k", "124k", and 115k" are all CORRECT.
- Predicted answers "100k" and "113k" are INCORRECT.
- Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target.
- The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question.
- For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer.
- Do not punish predicted answers if they omit information that would be clearly inferred from the question.
- For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California".
- Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question.
- For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question.
- For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed.
- Do not punish for typos in people's name if it's clearly the same name.
- For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".
Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
```
Question: {input_query}
Gold target: {expected_answer}
Predicted answer: {generated_answer}
```
Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
C: NOT_ATTEMPTED
Just return the letters "A", "B", or "C", with no text around it.
""".strip()
llm_as_judge_405b_simpleqa = ScoringFn(
identifier="llm-as-judge::405b-simpleqa",
description="Llm As Judge Scoring Function for SimpleQA Benchmark (https://github.com/openai/simple-evals/blob/main/simpleqa_eval.py)",
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-405b-simpleqa",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template=GRADER_TEMPLATE,
judge_score_regexes=[r"(A|B|C)"],
aggregation_functions=[AggregationFunctionType.categorical_count.value],
),
)

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
identifier="llm-as-judge::base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template="Enter custom LLM as Judge Prompt Template",
),
)

View file

@ -1,79 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns
"""
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.inference_api = inference_api
self.supported_fn_defs_registry = {
llm_as_judge_base.identifier: llm_as_judge_base,
llm_as_judge_405b_simpleqa.identifier: llm_as_judge_405b_simpleqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
# override params if scoring_params is provided
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
assert fn_def.params.judge_score_regexes is not None, "LLM Judge judge_score_regexes not found."
input_query = input_row["input_query"]
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
judge_input_msg = fn_def.params.prompt_template.format(
input_query=input_query,
expected_answer=expected_answer,
generated_answer=generated_answer,
)
judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model,
messages=[
UserMessage(
content=judge_input_msg,
),
],
)
content = judge_response.completion_message.content
rating_regexes = fn_def.params.judge_score_regexes
judge_rating = None
for regex in rating_regexes:
match = re.search(regex, content)
if match:
judge_rating = match.group(1)
break
return {
"score": judge_rating,
"judge_feedback": content,
}

View file

@ -7,22 +7,28 @@
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import kvstore_dependencies
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.eval,
api=Api.evaluation,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
pip_packages=[
"matplotlib",
"pillow",
"pandas",
"scikit-learn",
]
+ kvstore_dependencies(),
module="llama_stack.providers.inline.evaluation.meta_reference",
config_class="llama_stack.providers.inline.evaluation.meta_reference.MetaReferenceEvaluationConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
Api.datasets,
Api.datasetio,
],
),
]

View file

@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::basic",
pip_packages=[],
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::llm-as-judge",
pip_packages=[],
module="llama_stack.providers.inline.scoring.llm_as_judge",
config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.inference,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::braintrust",
pip_packages=["autoevals", "openai"],
module="llama_stack.providers.inline.scoring.braintrust",
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]

View file

@ -5,14 +5,12 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.distribution.datatypes import Api
class ColumnName(Enum):
@ -75,29 +73,31 @@ VALID_SCHEMAS_FOR_EVAL = [
]
def get_valid_schemas(api_str: str):
if api_str == Api.scoring.value:
return VALID_SCHEMAS_FOR_SCORING
elif api_str == Api.eval.value:
return VALID_SCHEMAS_FOR_EVAL
else:
raise ValueError(f"Invalid API string: {api_str}")
# TODO(xiyan): add this back
# def get_valid_schemas(api_str: str):
# if api_str == Api.scoring.value:
# return VALID_SCHEMAS_FOR_SCORING
# elif api_str == Api.eval.value:
# return VALID_SCHEMAS_FOR_EVAL
# else:
# raise ValueError(f"Invalid API string: {api_str}")
def validate_dataset_schema(
dataset_schema: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
# def validate_dataset_schema(
# dataset_schema: Dict[str, Any],
# expected_schemas: List[Dict[str, Any]],
# ):
# if dataset_schema not in expected_schemas:
# raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
def validate_row_schema(
input_row: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
# def validate_row_schema(
# input_row: Dict[str, Any],
# expected_schemas: List[Dict[str, Any]],
# ):
# for schema in expected_schemas:
# if all(key in input_row for key in schema):
# return
raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
# raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")

View file

@ -23,9 +23,7 @@ def get_distribution_template() -> DistributionTemplate:
"safety": ["remote::bedrock"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",

View file

@ -14,15 +14,9 @@ distribution_spec:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search

View file

@ -3,10 +3,8 @@ image_name: bedrock
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
@ -42,14 +40,6 @@ providers:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
@ -65,17 +55,6 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
@ -133,7 +112,6 @@ models:
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch

View file

@ -13,15 +13,9 @@ distribution_spec:
- remote::pgvector
agents:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
telemetry:
- inline::meta-reference
tool_runtime:

View file

@ -27,9 +27,7 @@ def get_distribution_template() -> DistributionTemplate:
"safety": ["inline::llama-guard"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"agents": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
"remote::brave-search",

View file

@ -3,10 +3,8 @@ image_name: cerebras
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
@ -41,14 +39,6 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
@ -64,17 +54,6 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -131,7 +110,6 @@ models:
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch

View file

@ -15,15 +15,9 @@ distribution_spec:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search

Some files were not shown because too many files have changed in this diff Show more