diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py
index f9f56119b..dbfc90452 100644
--- a/docs/openapi_generator/generate.py
+++ b/docs/openapi_generator/generate.py
@@ -49,6 +49,7 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
+from llama_stack.apis.eval_tasks import * # noqa: F403
class LlamaStack(
@@ -63,6 +64,7 @@ class LlamaStack(
PostTraining,
Memory,
Eval,
+ EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 363d968f9..8156039a9 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-11 13:59:59.544511"
},
"servers": [
{
@@ -469,7 +469,7 @@
}
}
},
- "/eval/evaluate": {
+ "/eval/evaluate_rows": {
"post": {
"responses": {
"200": {
@@ -501,47 +501,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/EvaluateRequest"
- }
- }
- },
- "required": true
- }
- }
- },
- "/eval/evaluate_batch": {
- "post": {
- "responses": {
- "200": {
- "description": "OK",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/Job"
- }
- }
- }
- }
- },
- "tags": [
- "Eval"
- ],
- "parameters": [
- {
- "name": "X-LlamaStack-ProviderData",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/EvaluateBatchRequest"
+ "$ref": "#/components/schemas/EvaluateRowsRequest"
}
}
},
@@ -766,6 +726,51 @@
]
}
},
+ "/eval_tasks/get": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/EvalTaskDefWithProvider"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "EvalTasks"
+ ],
+ "parameters": [
+ {
+ "name": "name",
+ "in": "query",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/memory_banks/get": {
"get": {
"responses": {
@@ -834,7 +839,7 @@
"schema": {
"oneOf": [
{
- "$ref": "#/components/schemas/ModelDefWithProvider"
+ "$ref": "#/components/schemas/Model"
},
{
"type": "null"
@@ -986,7 +991,7 @@
"schema": {
"oneOf": [
{
- "$ref": "#/components/schemas/ShieldDefWithProvider"
+ "$ref": "#/components/schemas/Shield"
},
{
"type": "null"
@@ -1002,7 +1007,7 @@
],
"parameters": [
{
- "name": "shield_type",
+ "name": "identifier",
"in": "query",
"required": true,
"schema": {
@@ -1317,6 +1322,14 @@
"Eval"
],
"parameters": [
+ {
+ "name": "task_id",
+ "in": "query",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
{
"name": "job_id",
"in": "query",
@@ -1362,6 +1375,14 @@
"Eval"
],
"parameters": [
+ {
+ "name": "task_id",
+ "in": "query",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
{
"name": "job_id",
"in": "query",
@@ -1412,6 +1433,36 @@
]
}
},
+ "/eval_tasks/list": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/jsonl": {
+ "schema": {
+ "$ref": "#/components/schemas/EvalTaskDefWithProvider"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "EvalTasks"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/memory_banks/list": {
"get": {
"responses": {
@@ -1463,7 +1514,7 @@
"content": {
"application/jsonl": {
"schema": {
- "$ref": "#/components/schemas/ModelDefWithProvider"
+ "$ref": "#/components/schemas/Model"
}
}
}
@@ -1592,7 +1643,7 @@
"content": {
"application/jsonl": {
"schema": {
- "$ref": "#/components/schemas/ShieldDefWithProvider"
+ "$ref": "#/components/schemas/Shield"
}
}
}
@@ -1760,6 +1811,39 @@
}
}
},
+ "/eval_tasks/register": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ }
+ },
+ "tags": [
+ "EvalTasks"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RegisterEvalTaskRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
"/memory_banks/register": {
"post": {
"responses": {
@@ -1797,7 +1881,14 @@
"post": {
"responses": {
"200": {
- "description": "OK"
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Model"
+ }
+ }
+ }
}
},
"tags": [
@@ -1863,7 +1954,14 @@
"post": {
"responses": {
"200": {
- "description": "OK"
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Shield"
+ }
+ }
+ }
}
},
"tags": [
@@ -1892,6 +1990,46 @@
}
}
},
+ "/eval/run_eval": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Job"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "Eval"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RunEvalRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
"/safety/run_shield": {
"post": {
"responses": {
@@ -4490,6 +4628,103 @@
"config"
]
},
+ "AppEvalTaskConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "app",
+ "default": "app"
+ },
+ "eval_candidate": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/ModelCandidate"
+ },
+ {
+ "$ref": "#/components/schemas/AgentCandidate"
+ }
+ ]
+ },
+ "scoring_params": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/RegexParserScoringFnParams"
+ }
+ ]
+ }
+ },
+ "num_examples": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "eval_candidate",
+ "scoring_params"
+ ]
+ },
+ "BenchmarkEvalTaskConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "benchmark",
+ "default": "benchmark"
+ },
+ "eval_candidate": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/ModelCandidate"
+ },
+ {
+ "$ref": "#/components/schemas/AgentCandidate"
+ }
+ ]
+ },
+ "num_examples": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "eval_candidate"
+ ]
+ },
+ "LLMAsJudgeScoringFnParams": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "llm_as_judge",
+ "default": "llm_as_judge"
+ },
+ "judge_model": {
+ "type": "string"
+ },
+ "prompt_template": {
+ "type": "string"
+ },
+ "judge_score_regexes": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "judge_model"
+ ]
+ },
"ModelCandidate": {
"type": "object",
"properties": {
@@ -4515,9 +4750,32 @@
"sampling_params"
]
},
- "EvaluateRequest": {
+ "RegexParserScoringFnParams": {
"type": "object",
"properties": {
+ "type": {
+ "type": "string",
+ "const": "regex_parser",
+ "default": "regex_parser"
+ },
+ "parsing_regexes": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ]
+ },
+ "EvaluateRowsRequest": {
+ "type": "object",
+ "properties": {
+ "task_id": {
+ "type": "string"
+ },
"input_rows": {
"type": "array",
"items": {
@@ -4546,28 +4804,29 @@
}
}
},
- "candidate": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/ModelCandidate"
- },
- {
- "$ref": "#/components/schemas/AgentCandidate"
- }
- ]
- },
"scoring_functions": {
"type": "array",
"items": {
"type": "string"
}
+ },
+ "task_config": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/BenchmarkEvalTaskConfig"
+ },
+ {
+ "$ref": "#/components/schemas/AppEvalTaskConfig"
+ }
+ ]
}
},
"additionalProperties": false,
"required": [
+ "task_id",
"input_rows",
- "candidate",
- "scoring_functions"
+ "scoring_functions",
+ "task_config"
]
},
"EvaluateResponse": {
@@ -4677,48 +4936,6 @@
"aggregated_results"
]
},
- "EvaluateBatchRequest": {
- "type": "object",
- "properties": {
- "dataset_id": {
- "type": "string"
- },
- "candidate": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/ModelCandidate"
- },
- {
- "$ref": "#/components/schemas/AgentCandidate"
- }
- ]
- },
- "scoring_functions": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "dataset_id",
- "candidate",
- "scoring_functions"
- ]
- },
- "Job": {
- "type": "object",
- "properties": {
- "job_id": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "job_id"
- ]
- },
"GetAgentsSessionRequest": {
"type": "object",
"properties": {
@@ -5085,6 +5302,11 @@
]
}
},
+ "type": {
+ "type": "string",
+ "const": "dataset",
+ "default": "dataset"
+ },
"provider_id": {
"type": "string"
}
@@ -5095,18 +5317,25 @@
"dataset_schema",
"url",
"metadata",
+ "type",
"provider_id"
]
},
- "ModelDefWithProvider": {
+ "EvalTaskDefWithProvider": {
"type": "object",
"properties": {
"identifier": {
"type": "string"
},
- "llama_model": {
+ "dataset_id": {
"type": "string"
},
+ "scoring_functions": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"metadata": {
"type": "object",
"additionalProperties": {
@@ -5132,6 +5361,11 @@
]
}
},
+ "type": {
+ "type": "string",
+ "const": "eval_task",
+ "default": "eval_task"
+ },
"provider_id": {
"type": "string"
}
@@ -5139,11 +5373,65 @@
"additionalProperties": false,
"required": [
"identifier",
- "llama_model",
+ "dataset_id",
+ "scoring_functions",
"metadata",
+ "type",
"provider_id"
]
},
+ "Model": {
+ "type": "object",
+ "properties": {
+ "identifier": {
+ "type": "string"
+ },
+ "provider_resource_id": {
+ "type": "string"
+ },
+ "provider_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "model",
+ "default": "model"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "identifier",
+ "provider_resource_id",
+ "provider_id",
+ "type",
+ "metadata"
+ ]
+ },
"PaginatedRowsResult": {
"type": "object",
"properties": {
@@ -5188,166 +5476,6 @@
"total_count"
]
},
- "Parameter": {
- "type": "object",
- "properties": {
- "name": {
- "type": "string"
- },
- "type": {
- "oneOf": [
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "string",
- "default": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "number",
- "default": "number"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "boolean",
- "default": "boolean"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "array",
- "default": "array"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "object",
- "default": "object"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "json",
- "default": "json"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "union",
- "default": "union"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "chat_completion_input",
- "default": "chat_completion_input"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "completion_input",
- "default": "completion_input"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "agent_turn_input",
- "default": "agent_turn_input"
- }
- },
- "additionalProperties": false,
- "required": [
- "type"
- ]
- }
- ]
- },
- "description": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "name",
- "type"
- ]
- },
"ScoringFnDefWithProvider": {
"type": "object",
"properties": {
@@ -5382,12 +5510,6 @@
]
}
},
- "parameters": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/Parameter"
- }
- },
"return_type": {
"oneOf": [
{
@@ -5532,27 +5654,21 @@
}
]
},
- "context": {
- "type": "object",
- "properties": {
- "judge_model": {
- "type": "string"
+ "params": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams"
},
- "prompt_template": {
- "type": "string"
- },
- "judge_score_regex": {
- "type": "array",
- "items": {
- "type": "string"
- }
+ {
+ "$ref": "#/components/schemas/RegexParserScoringFnParams"
}
- },
- "additionalProperties": false,
- "required": [
- "judge_model"
]
},
+ "type": {
+ "type": "string",
+ "const": "scoring_fn",
+ "default": "scoring_fn"
+ },
"provider_id": {
"type": "string"
}
@@ -5561,20 +5677,31 @@
"required": [
"identifier",
"metadata",
- "parameters",
"return_type",
+ "type",
"provider_id"
]
},
- "ShieldDefWithProvider": {
+ "Shield": {
"type": "object",
"properties": {
"identifier": {
"type": "string"
},
- "type": {
+ "provider_resource_id": {
"type": "string"
},
+ "provider_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "shield",
+ "default": "shield"
+ },
+ "shield_type": {
+ "$ref": "#/components/schemas/ShieldType"
+ },
"params": {
"type": "object",
"additionalProperties": {
@@ -5599,17 +5726,26 @@
}
]
}
- },
- "provider_id": {
- "type": "string"
}
},
"additionalProperties": false,
"required": [
"identifier",
+ "provider_resource_id",
+ "provider_id",
"type",
- "params",
- "provider_id"
+ "shield_type",
+ "params"
+ ],
+ "title": "A safety shield resource that can be used to check content"
+ },
+ "ShieldType": {
+ "type": "string",
+ "enum": [
+ "generic_content_shield",
+ "llama_guard",
+ "code_scanner",
+ "prompt_guard"
]
},
"Trace": {
@@ -5867,12 +6003,16 @@
"JobCancelRequest": {
"type": "object",
"properties": {
+ "task_id": {
+ "type": "string"
+ },
"job_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
+ "task_id",
"job_id"
]
},
@@ -6514,6 +6654,18 @@
"dataset_def"
]
},
+ "RegisterEvalTaskRequest": {
+ "type": "object",
+ "properties": {
+ "eval_task_def": {
+ "$ref": "#/components/schemas/EvalTaskDefWithProvider"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "eval_task_def"
+ ]
+ },
"RegisterMemoryBankRequest": {
"type": "object",
"properties": {
@@ -6542,13 +6694,44 @@
"RegisterModelRequest": {
"type": "object",
"properties": {
- "model": {
- "$ref": "#/components/schemas/ModelDefWithProvider"
+ "model_id": {
+ "type": "string"
+ },
+ "provider_model_id": {
+ "type": "string"
+ },
+ "provider_id": {
+ "type": "string"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
"required": [
- "model"
+ "model_id"
]
},
"RegisterScoringFunctionRequest": {
@@ -6566,19 +6749,89 @@
"RegisterShieldRequest": {
"type": "object",
"properties": {
- "shield": {
- "$ref": "#/components/schemas/ShieldDefWithProvider"
+ "shield_id": {
+ "type": "string"
+ },
+ "shield_type": {
+ "$ref": "#/components/schemas/ShieldType"
+ },
+ "provider_shield_id": {
+ "type": "string"
+ },
+ "provider_id": {
+ "type": "string"
+ },
+ "params": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
"required": [
- "shield"
+ "shield_id",
+ "shield_type"
+ ]
+ },
+ "RunEvalRequest": {
+ "type": "object",
+ "properties": {
+ "task_id": {
+ "type": "string"
+ },
+ "task_config": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/BenchmarkEvalTaskConfig"
+ },
+ {
+ "$ref": "#/components/schemas/AppEvalTaskConfig"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "task_id",
+ "task_config"
+ ]
+ },
+ "Job": {
+ "type": "object",
+ "properties": {
+ "job_id": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "job_id"
]
},
"RunShieldRequest": {
"type": "object",
"properties": {
- "shield_type": {
+ "shield_id": {
"type": "string"
},
"messages": {
@@ -6628,7 +6881,7 @@
},
"additionalProperties": false,
"required": [
- "shield_type",
+ "shield_id",
"messages",
"params"
]
@@ -6674,9 +6927,23 @@
}
},
"scoring_functions": {
- "type": "array",
- "items": {
- "type": "string"
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/RegexParserScoringFnParams"
+ }
+ ]
+ },
+ {
+ "type": "null"
+ }
+ ]
}
}
},
@@ -6708,9 +6975,23 @@
"type": "string"
},
"scoring_functions": {
- "type": "array",
- "items": {
- "type": "string"
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/RegexParserScoringFnParams"
+ }
+ ]
+ },
+ {
+ "type": "null"
+ }
+ ]
}
},
"save_results_dataset": {
@@ -7063,56 +7344,59 @@
],
"tags": [
{
- "name": "Memory"
- },
- {
- "name": "Inference"
- },
- {
- "name": "Eval"
- },
- {
- "name": "MemoryBanks"
- },
- {
- "name": "Models"
- },
- {
- "name": "BatchInference"
- },
- {
- "name": "PostTraining"
- },
- {
- "name": "Agents"
- },
- {
- "name": "Shields"
- },
- {
- "name": "Telemetry"
- },
- {
- "name": "Inspect"
- },
- {
- "name": "DatasetIO"
- },
- {
- "name": "SyntheticDataGeneration"
+ "name": "ScoringFunctions"
},
{
"name": "Datasets"
},
{
- "name": "Scoring"
- },
- {
- "name": "ScoringFunctions"
+ "name": "Inspect"
},
{
"name": "Safety"
},
+ {
+ "name": "Eval"
+ },
+ {
+ "name": "Inference"
+ },
+ {
+ "name": "BatchInference"
+ },
+ {
+ "name": "Agents"
+ },
+ {
+ "name": "PostTraining"
+ },
+ {
+ "name": "Shields"
+ },
+ {
+ "name": "Memory"
+ },
+ {
+ "name": "Scoring"
+ },
+ {
+ "name": "SyntheticDataGeneration"
+ },
+ {
+ "name": "EvalTasks"
+ },
+ {
+ "name": "MemoryBanks"
+ },
+ {
+ "name": "DatasetIO"
+ },
+ {
+ "name": "Models"
+ },
+ {
+ "name": "Telemetry"
+ },
{
"name": "BuiltinTool",
"description": ""
@@ -7377,13 +7661,29 @@
"name": "AgentCandidate",
"description": ""
},
+ {
+ "name": "AppEvalTaskConfig",
+ "description": ""
+ },
+ {
+ "name": "BenchmarkEvalTaskConfig",
+ "description": ""
+ },
+ {
+ "name": "LLMAsJudgeScoringFnParams",
+ "description": ""
+ },
{
"name": "ModelCandidate",
"description": ""
},
{
- "name": "EvaluateRequest",
- "description": ""
+ "name": "RegexParserScoringFnParams",
+ "description": ""
+ },
+ {
+ "name": "EvaluateRowsRequest",
+ "description": ""
},
{
"name": "EvaluateResponse",
@@ -7393,14 +7693,6 @@
"name": "ScoringResult",
"description": ""
},
- {
- "name": "EvaluateBatchRequest",
- "description": ""
- },
- {
- "name": "Job",
- "description": ""
- },
{
"name": "GetAgentsSessionRequest",
"description": ""
@@ -7434,24 +7726,28 @@
"description": ""
},
{
- "name": "ModelDefWithProvider",
- "description": ""
+ "name": "EvalTaskDefWithProvider",
+ "description": ""
+ },
+ {
+ "name": "Model",
+ "description": ""
},
{
"name": "PaginatedRowsResult",
"description": ""
},
- {
- "name": "Parameter",
- "description": ""
- },
{
"name": "ScoringFnDefWithProvider",
"description": ""
},
{
- "name": "ShieldDefWithProvider",
- "description": ""
+ "name": "Shield",
+ "description": "A safety shield resource that can be used to check content\n\n"
+ },
+ {
+ "name": "ShieldType",
+ "description": ""
},
{
"name": "Trace",
@@ -7573,6 +7869,10 @@
"name": "RegisterDatasetRequest",
"description": ""
},
+ {
+ "name": "RegisterEvalTaskRequest",
+ "description": ""
+ },
{
"name": "RegisterMemoryBankRequest",
"description": ""
@@ -7589,6 +7889,14 @@
"name": "RegisterShieldRequest",
"description": ""
},
+ {
+ "name": "RunEvalRequest",
+ "description": ""
+ },
+ {
+ "name": "Job",
+ "description": ""
+ },
{
"name": "RunShieldRequest",
"description": ""
@@ -7651,6 +7959,7 @@
"DatasetIO",
"Datasets",
"Eval",
+ "EvalTasks",
"Inference",
"Inspect",
"Memory",
@@ -7680,11 +7989,13 @@
"AgentTurnResponseStreamChunk",
"AgentTurnResponseTurnCompletePayload",
"AgentTurnResponseTurnStartPayload",
+ "AppEvalTaskConfig",
"Attachment",
"BatchChatCompletionRequest",
"BatchChatCompletionResponse",
"BatchCompletionRequest",
"BatchCompletionResponse",
+ "BenchmarkEvalTaskConfig",
"BuiltinTool",
"CancelTrainingJobRequest",
"ChatCompletionRequest",
@@ -7708,9 +8019,9 @@
"DoraFinetuningConfig",
"EmbeddingsRequest",
"EmbeddingsResponse",
- "EvaluateBatchRequest",
- "EvaluateRequest",
+ "EvalTaskDefWithProvider",
"EvaluateResponse",
+ "EvaluateRowsRequest",
"FinetuningAlgorithm",
"FunctionCallToolDefinition",
"GetAgentsSessionRequest",
@@ -7724,6 +8035,7 @@
"JobStatus",
"KeyValueMemoryBankDef",
"KeywordMemoryBankDef",
+ "LLMAsJudgeScoringFnParams",
"LogEventRequest",
"LogSeverity",
"LoraFinetuningConfig",
@@ -7731,11 +8043,10 @@
"MemoryRetrievalStep",
"MemoryToolDefinition",
"MetricEvent",
+ "Model",
"ModelCandidate",
- "ModelDefWithProvider",
"OptimizerConfig",
"PaginatedRowsResult",
- "Parameter",
"PhotogenToolDefinition",
"PostTrainingJob",
"PostTrainingJobArtifactsResponse",
@@ -7748,7 +8059,9 @@
"QueryDocumentsRequest",
"QueryDocumentsResponse",
"RLHFAlgorithm",
+ "RegexParserScoringFnParams",
"RegisterDatasetRequest",
+ "RegisterEvalTaskRequest",
"RegisterMemoryBankRequest",
"RegisterModelRequest",
"RegisterScoringFunctionRequest",
@@ -7756,6 +8069,7 @@
"RestAPIExecutionConfig",
"RestAPIMethod",
"RouteInfo",
+ "RunEvalRequest",
"RunShieldRequest",
"RunShieldResponse",
"SafetyViolation",
@@ -7769,8 +8083,9 @@
"ScoringResult",
"SearchToolDefinition",
"Session",
+ "Shield",
"ShieldCallStep",
- "ShieldDefWithProvider",
+ "ShieldType",
"SpanEndPayload",
"SpanStartPayload",
"SpanStatus",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 7dd231965..0e6571301 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -218,6 +218,30 @@ components:
- event_type
- turn_id
type: object
+ AppEvalTaskConfig:
+ additionalProperties: false
+ properties:
+ eval_candidate:
+ oneOf:
+ - $ref: '#/components/schemas/ModelCandidate'
+ - $ref: '#/components/schemas/AgentCandidate'
+ num_examples:
+ type: integer
+ scoring_params:
+ additionalProperties:
+ oneOf:
+ - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
+ - $ref: '#/components/schemas/RegexParserScoringFnParams'
+ type: object
+ type:
+ const: app
+ default: app
+ type: string
+ required:
+ - type
+ - eval_candidate
+ - scoring_params
+ type: object
Attachment:
additionalProperties: false
properties:
@@ -322,6 +346,23 @@ components:
required:
- completion_message_batch
type: object
+ BenchmarkEvalTaskConfig:
+ additionalProperties: false
+ properties:
+ eval_candidate:
+ oneOf:
+ - $ref: '#/components/schemas/ModelCandidate'
+ - $ref: '#/components/schemas/AgentCandidate'
+ num_examples:
+ type: integer
+ type:
+ const: benchmark
+ default: benchmark
+ type: string
+ required:
+ - type
+ - eval_candidate
+ type: object
BuiltinTool:
enum:
- brave_search
@@ -790,6 +831,10 @@ components:
type: object
provider_id:
type: string
+ type:
+ const: dataset
+ default: dataset
+ type: string
url:
$ref: '#/components/schemas/URL'
required:
@@ -797,6 +842,7 @@ components:
- dataset_schema
- url
- metadata
+ - type
- provider_id
type: object
DeleteAgentsRequest:
@@ -872,51 +918,40 @@ components:
required:
- embeddings
type: object
- EvaluateBatchRequest:
+ EvalTaskDefWithProvider:
additionalProperties: false
properties:
- candidate:
- oneOf:
- - $ref: '#/components/schemas/ModelCandidate'
- - $ref: '#/components/schemas/AgentCandidate'
dataset_id:
type: string
+ identifier:
+ type: string
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ provider_id:
+ type: string
scoring_functions:
items:
type: string
type: array
+ type:
+ const: eval_task
+ default: eval_task
+ type: string
required:
+ - identifier
- dataset_id
- - candidate
- - scoring_functions
- type: object
- EvaluateRequest:
- additionalProperties: false
- properties:
- candidate:
- oneOf:
- - $ref: '#/components/schemas/ModelCandidate'
- - $ref: '#/components/schemas/AgentCandidate'
- input_rows:
- items:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type: array
- scoring_functions:
- items:
- type: string
- type: array
- required:
- - input_rows
- - candidate
- scoring_functions
+ - metadata
+ - type
+ - provider_id
type: object
EvaluateResponse:
additionalProperties: false
@@ -941,6 +976,37 @@ components:
- generations
- scores
type: object
+ EvaluateRowsRequest:
+ additionalProperties: false
+ properties:
+ input_rows:
+ items:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ type: array
+ scoring_functions:
+ items:
+ type: string
+ type: array
+ task_config:
+ oneOf:
+ - $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
+ - $ref: '#/components/schemas/AppEvalTaskConfig'
+ task_id:
+ type: string
+ required:
+ - task_id
+ - input_rows
+ - scoring_functions
+ - task_config
+ type: object
FinetuningAlgorithm:
enum:
- full
@@ -1082,7 +1148,10 @@ components:
properties:
job_id:
type: string
+ task_id:
+ type: string
required:
+ - task_id
- job_id
type: object
JobStatus:
@@ -1124,6 +1193,25 @@ components:
- provider_id
- type
type: object
+ LLMAsJudgeScoringFnParams:
+ additionalProperties: false
+ properties:
+ judge_model:
+ type: string
+ judge_score_regexes:
+ items:
+ type: string
+ type: array
+ prompt_template:
+ type: string
+ type:
+ const: llm_as_judge
+ default: llm_as_judge
+ type: string
+ required:
+ - type
+ - judge_model
+ type: object
LogEventRequest:
additionalProperties: false
properties:
@@ -1405,6 +1493,36 @@ components:
- value
- unit
type: object
+ Model:
+ additionalProperties: false
+ properties:
+ identifier:
+ type: string
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ provider_id:
+ type: string
+ provider_resource_id:
+ type: string
+ type:
+ const: model
+ default: model
+ type: string
+ required:
+ - identifier
+ - provider_resource_id
+ - provider_id
+ - type
+ - metadata
+ type: object
ModelCandidate:
additionalProperties: false
properties:
@@ -1423,31 +1541,6 @@ components:
- model
- sampling_params
type: object
- ModelDefWithProvider:
- additionalProperties: false
- properties:
- identifier:
- type: string
- llama_model:
- type: string
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- provider_id:
- type: string
- required:
- - identifier
- - llama_model
- - metadata
- - provider_id
- type: object
OptimizerConfig:
additionalProperties: false
properties:
@@ -1492,109 +1585,6 @@ components:
- rows
- total_count
type: object
- Parameter:
- additionalProperties: false
- properties:
- description:
- type: string
- name:
- type: string
- type:
- oneOf:
- - additionalProperties: false
- properties:
- type:
- const: string
- default: string
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: number
- default: number
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: boolean
- default: boolean
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: array
- default: array
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: object
- default: object
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: json
- default: json
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: union
- default: union
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: chat_completion_input
- default: chat_completion_input
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: completion_input
- default: completion_input
- type: string
- required:
- - type
- type: object
- - additionalProperties: false
- properties:
- type:
- const: agent_turn_input
- default: agent_turn_input
- type: string
- required:
- - type
- type: object
- required:
- - name
- - type
- type: object
PhotogenToolDefinition:
additionalProperties: false
properties:
@@ -1844,6 +1834,20 @@ components:
enum:
- dpo
type: string
+ RegexParserScoringFnParams:
+ additionalProperties: false
+ properties:
+ parsing_regexes:
+ items:
+ type: string
+ type: array
+ type:
+ const: regex_parser
+ default: regex_parser
+ type: string
+ required:
+ - type
+ type: object
RegisterDatasetRequest:
additionalProperties: false
properties:
@@ -1852,6 +1856,14 @@ components:
required:
- dataset_def
type: object
+ RegisterEvalTaskRequest:
+ additionalProperties: false
+ properties:
+ eval_task_def:
+ $ref: '#/components/schemas/EvalTaskDefWithProvider'
+ required:
+ - eval_task_def
+ type: object
RegisterMemoryBankRequest:
additionalProperties: false
properties:
@@ -1867,10 +1879,24 @@ components:
RegisterModelRequest:
additionalProperties: false
properties:
- model:
- $ref: '#/components/schemas/ModelDefWithProvider'
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ model_id:
+ type: string
+ provider_id:
+ type: string
+ provider_model_id:
+ type: string
required:
- - model
+ - model_id
type: object
RegisterScoringFunctionRequest:
additionalProperties: false
@@ -1883,10 +1909,27 @@ components:
RegisterShieldRequest:
additionalProperties: false
properties:
- shield:
- $ref: '#/components/schemas/ShieldDefWithProvider'
+ params:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ provider_id:
+ type: string
+ provider_shield_id:
+ type: string
+ shield_id:
+ type: string
+ shield_type:
+ $ref: '#/components/schemas/ShieldType'
required:
- - shield
+ - shield_id
+ - shield_type
type: object
RestAPIExecutionConfig:
additionalProperties: false
@@ -1952,6 +1995,19 @@ components:
- method
- provider_types
type: object
+ RunEvalRequest:
+ additionalProperties: false
+ properties:
+ task_config:
+ oneOf:
+ - $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
+ - $ref: '#/components/schemas/AppEvalTaskConfig'
+ task_id:
+ type: string
+ required:
+ - task_id
+ - task_config
+ type: object
RunShieldRequest:
additionalProperties: false
properties:
@@ -1973,10 +2029,10 @@ components:
- type: array
- type: object
type: object
- shield_type:
+ shield_id:
type: string
required:
- - shield_type
+ - shield_id
- messages
- params
type: object
@@ -2045,9 +2101,13 @@ components:
save_results_dataset:
type: boolean
scoring_functions:
- items:
- type: string
- type: array
+ additionalProperties:
+ oneOf:
+ - oneOf:
+ - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
+ - $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - type: 'null'
+ type: object
required:
- dataset_id
- scoring_functions
@@ -2081,9 +2141,13 @@ components:
type: object
type: array
scoring_functions:
- items:
- type: string
- type: array
+ additionalProperties:
+ oneOf:
+ - oneOf:
+ - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
+ - $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - type: 'null'
+ type: object
required:
- input_rows
- scoring_functions
@@ -2101,20 +2165,6 @@ components:
ScoringFnDefWithProvider:
additionalProperties: false
properties:
- context:
- additionalProperties: false
- properties:
- judge_model:
- type: string
- judge_score_regex:
- items:
- type: string
- type: array
- prompt_template:
- type: string
- required:
- - judge_model
- type: object
description:
type: string
identifier:
@@ -2129,10 +2179,10 @@ components:
- type: array
- type: object
type: object
- parameters:
- items:
- $ref: '#/components/schemas/Parameter'
- type: array
+ params:
+ oneOf:
+ - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
+ - $ref: '#/components/schemas/RegexParserScoringFnParams'
provider_id:
type: string
return_type:
@@ -2227,11 +2277,15 @@ components:
required:
- type
type: object
+ type:
+ const: scoring_fn
+ default: scoring_fn
+ type: string
required:
- identifier
- metadata
- - parameters
- return_type
+ - type
- provider_id
type: object
ScoringResult:
@@ -2320,6 +2374,40 @@ components:
- started_at
title: A single session of an interaction with an Agentic System.
type: object
+ Shield:
+ additionalProperties: false
+ properties:
+ identifier:
+ type: string
+ params:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ provider_id:
+ type: string
+ provider_resource_id:
+ type: string
+ shield_type:
+ $ref: '#/components/schemas/ShieldType'
+ type:
+ const: shield
+ default: shield
+ type: string
+ required:
+ - identifier
+ - provider_resource_id
+ - provider_id
+ - type
+ - shield_type
+ - params
+ title: A safety shield resource that can be used to check content
+ type: object
ShieldCallStep:
additionalProperties: false
properties:
@@ -2344,31 +2432,13 @@ components:
- step_id
- step_type
type: object
- ShieldDefWithProvider:
- additionalProperties: false
- properties:
- identifier:
- type: string
- params:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- provider_id:
- type: string
- type:
- type: string
- required:
- - identifier
- - type
- - params
- - provider_id
- type: object
+ ShieldType:
+ enum:
+ - generic_content_shield
+ - llama_guard
+ - code_scanner
+ - prompt_guard
+ type: string
SpanEndPayload:
additionalProperties: false
properties:
@@ -2998,7 +3068,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
+ \ draft and subject to change.\n Generated at 2024-11-11 13:59:59.544511"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3387,7 +3457,7 @@ paths:
description: OK
tags:
- Datasets
- /eval/evaluate:
+ /eval/evaluate_rows:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
@@ -3401,7 +3471,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/EvaluateRequest'
+ $ref: '#/components/schemas/EvaluateRowsRequest'
required: true
responses:
'200':
@@ -3412,31 +3482,6 @@ paths:
description: OK
tags:
- Eval
- /eval/evaluate_batch:
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-ProviderData
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/EvaluateBatchRequest'
- required: true
- responses:
- '200':
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/Job'
- description: OK
- tags:
- - Eval
/eval/job/cancel:
post:
parameters:
@@ -3461,6 +3506,11 @@ paths:
/eval/job/result:
get:
parameters:
+ - in: query
+ name: task_id
+ required: true
+ schema:
+ type: string
- in: query
name: job_id
required: true
@@ -3485,6 +3535,11 @@ paths:
/eval/job/status:
get:
parameters:
+ - in: query
+ name: task_id
+ required: true
+ schema:
+ type: string
- in: query
name: job_id
required: true
@@ -3508,6 +3563,97 @@ paths:
description: OK
tags:
- Eval
+ /eval/run_eval:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RunEvalRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/Job'
+ description: OK
+ tags:
+ - Eval
+ /eval_tasks/get:
+ get:
+ parameters:
+ - in: query
+ name: name
+ required: true
+ schema:
+ type: string
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ oneOf:
+ - $ref: '#/components/schemas/EvalTaskDefWithProvider'
+ - type: 'null'
+ description: OK
+ tags:
+ - EvalTasks
+ /eval_tasks/list:
+ get:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/jsonl:
+ schema:
+ $ref: '#/components/schemas/EvalTaskDefWithProvider'
+ description: OK
+ tags:
+ - EvalTasks
+ /eval_tasks/register:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RegisterEvalTaskRequest'
+ required: true
+ responses:
+ '200':
+ description: OK
+ tags:
+ - EvalTasks
/health:
get:
parameters:
@@ -3747,7 +3893,7 @@ paths:
application/json:
schema:
oneOf:
- - $ref: '#/components/schemas/ModelDefWithProvider'
+ - $ref: '#/components/schemas/Model'
- type: 'null'
description: OK
tags:
@@ -3767,7 +3913,7 @@ paths:
content:
application/jsonl:
schema:
- $ref: '#/components/schemas/ModelDefWithProvider'
+ $ref: '#/components/schemas/Model'
description: OK
tags:
- Models
@@ -3789,6 +3935,10 @@ paths:
required: true
responses:
'200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/Model'
description: OK
tags:
- Models
@@ -4143,7 +4293,7 @@ paths:
get:
parameters:
- in: query
- name: shield_type
+ name: identifier
required: true
schema:
type: string
@@ -4160,7 +4310,7 @@ paths:
application/json:
schema:
oneOf:
- - $ref: '#/components/schemas/ShieldDefWithProvider'
+ - $ref: '#/components/schemas/Shield'
- type: 'null'
description: OK
tags:
@@ -4180,7 +4330,7 @@ paths:
content:
application/jsonl:
schema:
- $ref: '#/components/schemas/ShieldDefWithProvider'
+ $ref: '#/components/schemas/Shield'
description: OK
tags:
- Shields
@@ -4202,6 +4352,10 @@ paths:
required: true
responses:
'200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/Shield'
description: OK
tags:
- Shields
@@ -4280,23 +4434,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Memory
-- name: Inference
-- name: Eval
-- name: MemoryBanks
-- name: Models
-- name: BatchInference
-- name: PostTraining
-- name: Agents
-- name: Shields
-- name: Telemetry
-- name: Inspect
-- name: DatasetIO
-- name: SyntheticDataGeneration
-- name: Datasets
-- name: Scoring
- name: ScoringFunctions
+- name: Datasets
+- name: Inspect
- name: Safety
+- name: Eval
+- name: Inference
+- name: BatchInference
+- name: Agents
+- name: PostTraining
+- name: Shields
+- name: Memory
+- name: Scoring
+- name: SyntheticDataGeneration
+- name: EvalTasks
+- name: MemoryBanks
+- name: DatasetIO
+- name: Models
+- name: Telemetry
- description:
name: BuiltinTool
- description:
name: AgentCandidate
+- description:
+ name: AppEvalTaskConfig
+- description:
+ name: BenchmarkEvalTaskConfig
+- description:
+ name: LLMAsJudgeScoringFnParams
- description:
name: ModelCandidate
-- description:
- name: EvaluateRequest
+ name: RegexParserScoringFnParams
+- description:
+ name: EvaluateRowsRequest
- description:
name: EvaluateResponse
- description:
name: ScoringResult
-- description:
- name: EvaluateBatchRequest
-- description:
- name: Job
- description:
name: GetAgentsSessionRequest
@@ -4544,20 +4706,24 @@ tags:
- description:
name: DatasetDefWithProvider
-- description:
- name: ModelDefWithProvider
+ name: EvalTaskDefWithProvider
+- description:
+ name: Model
- description:
name: PaginatedRowsResult
-- description:
- name: Parameter
- description:
name: ScoringFnDefWithProvider
-- description:
- name: ShieldDefWithProvider
+- description: 'A safety shield resource that can be used to check content
+
+
+ '
+ name: Shield
+- description:
+ name: ShieldType
- description:
name: Trace
- description: 'Checkpoint created during training runs
@@ -4647,6 +4813,9 @@ tags:
- description:
name: RegisterDatasetRequest
+- description:
+ name: RegisterEvalTaskRequest
- description:
name: RegisterMemoryBankRequest
@@ -4659,6 +4828,10 @@ tags:
- description:
name: RegisterShieldRequest
+- description:
+ name: RunEvalRequest
+- description:
+ name: Job
- description:
name: RunShieldRequest
@@ -4708,6 +4881,7 @@ x-tagGroups:
- DatasetIO
- Datasets
- Eval
+ - EvalTasks
- Inference
- Inspect
- Memory
@@ -4734,11 +4908,13 @@ x-tagGroups:
- AgentTurnResponseStreamChunk
- AgentTurnResponseTurnCompletePayload
- AgentTurnResponseTurnStartPayload
+ - AppEvalTaskConfig
- Attachment
- BatchChatCompletionRequest
- BatchChatCompletionResponse
- BatchCompletionRequest
- BatchCompletionResponse
+ - BenchmarkEvalTaskConfig
- BuiltinTool
- CancelTrainingJobRequest
- ChatCompletionRequest
@@ -4762,9 +4938,9 @@ x-tagGroups:
- DoraFinetuningConfig
- EmbeddingsRequest
- EmbeddingsResponse
- - EvaluateBatchRequest
- - EvaluateRequest
+ - EvalTaskDefWithProvider
- EvaluateResponse
+ - EvaluateRowsRequest
- FinetuningAlgorithm
- FunctionCallToolDefinition
- GetAgentsSessionRequest
@@ -4778,6 +4954,7 @@ x-tagGroups:
- JobStatus
- KeyValueMemoryBankDef
- KeywordMemoryBankDef
+ - LLMAsJudgeScoringFnParams
- LogEventRequest
- LogSeverity
- LoraFinetuningConfig
@@ -4785,11 +4962,10 @@ x-tagGroups:
- MemoryRetrievalStep
- MemoryToolDefinition
- MetricEvent
+ - Model
- ModelCandidate
- - ModelDefWithProvider
- OptimizerConfig
- PaginatedRowsResult
- - Parameter
- PhotogenToolDefinition
- PostTrainingJob
- PostTrainingJobArtifactsResponse
@@ -4802,7 +4978,9 @@ x-tagGroups:
- QueryDocumentsRequest
- QueryDocumentsResponse
- RLHFAlgorithm
+ - RegexParserScoringFnParams
- RegisterDatasetRequest
+ - RegisterEvalTaskRequest
- RegisterMemoryBankRequest
- RegisterModelRequest
- RegisterScoringFunctionRequest
@@ -4810,6 +4988,7 @@ x-tagGroups:
- RestAPIExecutionConfig
- RestAPIMethod
- RouteInfo
+ - RunEvalRequest
- RunShieldRequest
- RunShieldResponse
- SafetyViolation
@@ -4823,8 +5002,9 @@ x-tagGroups:
- ScoringResult
- SearchToolDefinition
- Session
+ - Shield
- ShieldCallStep
- - ShieldDefWithProvider
+ - ShieldType
- SpanEndPayload
- SpanStartPayload
- SpanStatus
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 50fb922fe..04a5a55d5 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -40,6 +40,10 @@ EvalCandidate = Annotated[
class BenchmarkEvalTaskConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
+ num_examples: Optional[int] = Field(
+ description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+ default=None,
+ )
@json_schema_type
@@ -50,6 +54,10 @@ class AppEvalTaskConfig(BaseModel):
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
+ num_examples: Optional[int] = Field(
+ description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+ default=None,
+ )
# we could optinally add any specific dataset config here
diff --git a/llama_stack/providers/adapters/datasetio/huggingface/__init__.py b/llama_stack/providers/adapters/datasetio/huggingface/__init__.py
new file mode 100644
index 000000000..db803d183
--- /dev/null
+++ b/llama_stack/providers/adapters/datasetio/huggingface/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import HuggingfaceDatasetIOConfig
+
+
+async def get_adapter_impl(
+ config: HuggingfaceDatasetIOConfig,
+ _deps,
+):
+ from .huggingface import HuggingfaceDatasetIOImpl
+
+ impl = HuggingfaceDatasetIOImpl(config)
+ await impl.initialize()
+ return impl
diff --git a/llama_stack/providers/adapters/datasetio/huggingface/config.py b/llama_stack/providers/adapters/datasetio/huggingface/config.py
new file mode 100644
index 000000000..89dbe53a0
--- /dev/null
+++ b/llama_stack/providers/adapters/datasetio/huggingface/config.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.datasetio import * # noqa: F401, F403
+
+
+class HuggingfaceDatasetIOConfig(BaseModel): ...
diff --git a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
new file mode 100644
index 000000000..598ca5cfd
--- /dev/null
+++ b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import List, Optional
+
+from llama_stack.apis.datasetio import * # noqa: F403
+
+
+import datasets as hf_datasets
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+
+from .config import HuggingfaceDatasetIOConfig
+
+
+def load_hf_dataset(dataset_def: DatasetDef):
+ if dataset_def.metadata.get("path", None):
+ return hf_datasets.load_dataset(**dataset_def.metadata)
+
+ df = get_dataframe_from_url(dataset_def.url)
+
+ if df is None:
+ raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+ dataset = hf_datasets.Dataset.from_pandas(df)
+ return dataset
+
+
+class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
+ def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
+ self.config = config
+ # local registry for keeping track of datasets within the provider
+ self.dataset_infos = {}
+
+ async def initialize(self) -> None:
+ pass
+
+ async def shutdown(self) -> None: ...
+
+ async def register_dataset(
+ self,
+ dataset_def: DatasetDef,
+ ) -> None:
+ self.dataset_infos[dataset_def.identifier] = dataset_def
+
+ async def list_datasets(self) -> List[DatasetDef]:
+ return list(self.dataset_infos.values())
+
+ async def get_rows_paginated(
+ self,
+ dataset_id: str,
+ rows_in_page: int,
+ page_token: Optional[str] = None,
+ filter_condition: Optional[str] = None,
+ ) -> PaginatedRowsResult:
+ dataset_def = self.dataset_infos[dataset_id]
+ loaded_dataset = load_hf_dataset(dataset_def)
+
+ if page_token and not page_token.isnumeric():
+ raise ValueError("Invalid page_token")
+
+ if page_token is None or len(page_token) == 0:
+ next_page_token = 0
+ else:
+ next_page_token = int(page_token)
+
+ start = next_page_token
+ if rows_in_page == -1:
+ end = len(loaded_dataset)
+ else:
+ end = min(start + rows_in_page, len(loaded_dataset))
+
+ rows = [loaded_dataset[i] for i in range(start, end)]
+
+ return PaginatedRowsResult(
+ rows=rows,
+ total_count=len(rows),
+ next_page_token=str(end),
+ )
diff --git a/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py b/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py
index a96d9bcab..a6fe4feb3 100644
--- a/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py
+++ b/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py
@@ -3,20 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import io
from typing import List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
-import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from urllib.parse import unquote
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import parse_data_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from .config import MetaReferenceDatasetIOConfig
@@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset):
if self.df is not None:
return
- # TODO: more robust support w/ data url
- if self.dataset_def.url.uri.endswith(".csv"):
- df = pandas.read_csv(self.dataset_def.url.uri)
- elif self.dataset_def.url.uri.endswith(".xlsx"):
- df = pandas.read_excel(self.dataset_def.url.uri)
- elif self.dataset_def.url.uri.startswith("data:"):
- parts = parse_data_url(self.dataset_def.url.uri)
- data = parts["data"]
- if parts["is_base64"]:
- data = base64.b64decode(data)
- else:
- data = unquote(data)
- encoding = parts["encoding"] or "utf-8"
- data = data.encode(encoding)
-
- mime_type = parts["mimetype"]
- mime_category = mime_type.split("/")[0]
- data_bytes = io.BytesIO(data)
-
- if mime_category == "text":
- df = pandas.read_csv(data_bytes)
- else:
- df = pandas.read_excel(data_bytes)
- else:
- raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
+ df = get_dataframe_from_url(self.dataset_def.url)
+ if df is None:
+ raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
self.df = self._validate_dataset_schema(df)
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index 4a61c9d93..48d8e2b04 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -9,6 +9,8 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
+from tqdm import tqdm
+
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTaskDef
@@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
self.eval_tasks = {}
- async def initialize(self) -> None: ...
+ async def initialize(self) -> None:
+ pass
async def shutdown(self) -> None: ...
@@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
- rows_in_page=-1,
+ rows_in_page=(
+ -1 if task_config.num_examples is None else task_config.num_examples
+ ),
)
res = await self.evaluate_rows(
task_id=task_id,
@@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
), "SamplingParams.max_tokens must be provided"
generations = []
- for x in input_rows:
+ for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value]))
response = await self.inference_api.completion(
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index c4add966d..6370ea5e5 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
- EqualityScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
- LlmAsJudgeScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
- SubsetOfScoringFn,
-)
from .config import MetaReferenceScoringConfig
+from .scoring_fn.equality_scoring_fn import EqualityScoringFn
+from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
+from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
index 99fa6cc3a..b54bf7ae8 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
equality = ScoringFnDef(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
- parameters=[],
return_type=NumberType(),
)
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
index cfef52160..68d77b8df 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -26,7 +26,6 @@ Total rating:
llm_as_judge_8b_correctness = ScoringFnDef(
identifier="meta-reference::llm_as_judge_8b_correctness",
description="Llm As Judge Scoring Function",
- parameters=[],
return_type=NumberType(),
params=LLMAsJudgeScoringFnParams(
prompt_template=JUDGE_PROMPT,
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
new file mode 100644
index 000000000..84e518887
--- /dev/null
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+MULTILINGUAL_ANSWER_REGEXES = [
+ r"Answer\s*:",
+ r"Answer\s*:", # Korean invisible character
+ r"উত্তর\s*:",
+ r"उत्तर\s*:",
+ r"উত্তরঃ",
+ r"উত্তর\s*:",
+ r"Antwort\s*:",
+ r"답변\s*:",
+ r"정답\s*:",
+ r"답\s*:",
+ r"答案\s*:",
+ r"答案\s*:",
+ r"答\s*:",
+ r"答\s*:",
+ r"答复\s*:",
+ r"答曰\s*:",
+ r"الإجابة:",
+ r"الجواب:",
+ r"إجابة:",
+ r"الإجابة النهائية:",
+ r"الإجابة الصحيحة:",
+ r"الإجابة الصحيحة هي:",
+ r"الإجابة هي:",
+ r"Respuesta\s*:",
+ r"Risposta\s*:",
+ r"答え\s*:",
+ r"答え\s*:",
+ r"回答\s*:",
+ r"回答\s*:",
+ r"解答\s*:",
+ r"Jawaban\s*:",
+ r"Réponse\s*:",
+ r"Resposta\s*:",
+ r"Jibu\s*:",
+ r"Idahun\s*:",
+ r"Ìdáhùn\s*:",
+ r"Idáhùn\s*:",
+ r"Àmọ̀nà\s*:",
+ r"Àdáhùn\s*:",
+ r"Ànúgọ\s*:",
+ r"Àṣàyàn\s*:",
+]
+
+MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
+ r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
+)
+
+regex_parser_multiple_choice_answer = ScoringFnDef(
+ identifier="meta-reference::regex_parser_multiple_choice_answer",
+ description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
+ return_type=NumberType(),
+ params=RegexParserScoringFnParams(
+ parsing_regexes=[
+ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
+ for x in MULTILINGUAL_ANSWER_REGEXES
+ ],
+ ),
+)
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
new file mode 100644
index 000000000..0aff2f535
--- /dev/null
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import re
+
+from .base_scoring_fn import BaseScoringFn
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import * # noqa: F403
+from .common import aggregate_accuracy
+
+from .fn_defs.regex_parser_multiple_choice_answer import (
+ regex_parser_multiple_choice_answer,
+)
+
+
+class RegexParserScoringFn(BaseScoringFn):
+ """
+ A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.supported_fn_defs_registry = {
+ regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
+ }
+
+ async def score_row(
+ self,
+ input_row: Dict[str, Any],
+ scoring_fn_identifier: Optional[str] = None,
+ scoring_params: Optional[ScoringFnParams] = None,
+ ) -> ScoringResultRow:
+ assert (
+ scoring_fn_identifier is not None
+ ), "Scoring function identifier not found."
+ fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+ if scoring_params is not None:
+ fn_def.params = scoring_params
+
+ assert (
+ fn_def.params is not None
+ and fn_def.params.type == ScoringConfigType.regex_parser.value
+ ), f"RegexParserScoringFnParams not found for {fn_def}."
+
+ expected_answer = input_row["expected_answer"]
+ generated_answer = input_row["generated_answer"]
+
+ # parse answer according to regex
+ parsed_answer = None
+ for regex in fn_def.params.parsing_regexes:
+ match = re.search(regex, generated_answer)
+ if match:
+ parsed_answer = match.group(1)
+ break
+
+ score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+ return {
+ "score": score,
+ }
+
+ async def aggregate(
+ self, scoring_results: List[ScoringResultRow]
+ ) -> Dict[str, Any]:
+ return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py
index 976bbd448..3fdeac997 100644
--- a/llama_stack/providers/registry/datasetio.py
+++ b/llama_stack/providers/registry/datasetio.py
@@ -19,4 +19,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
api_dependencies=[],
),
+ remote_provider_spec(
+ api=Api.datasetio,
+ adapter=AdapterSpec(
+ adapter_type="huggingface",
+ pip_packages=[
+ "datasets",
+ ],
+ module="llama_stack.providers.adapters.datasetio.huggingface",
+ config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig",
+ ),
+ ),
]
diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py
index 7d7615b55..d810d5e02 100644
--- a/llama_stack/providers/tests/datasetio/fixtures.py
+++ b/llama_stack/providers/tests/datasetio/fixtures.py
@@ -31,7 +31,20 @@ def datasetio_meta_reference() -> ProviderFixture:
)
-DATASETIO_FIXTURES = ["meta_reference", "remote"]
+@pytest.fixture(scope="session")
+def datasetio_huggingface() -> ProviderFixture:
+ return ProviderFixture(
+ providers=[
+ Provider(
+ provider_id="huggingface",
+ provider_type="remote::huggingface",
+ config={},
+ )
+ ],
+ )
+
+
+DATASETIO_FIXTURES = ["meta_reference", "remote", "huggingface"]
@pytest_asyncio.fixture(scope="session")
diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
index 064feb611..985a8bc37 100644
--- a/llama_stack/providers/tests/eval/conftest.py
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -34,6 +34,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
),
+ pytest.param(
+ {
+ "eval": "meta_reference",
+ "scoring": "meta_reference",
+ "datasetio": "huggingface",
+ "inference": "together",
+ },
+ id="meta_reference_eval_together_inference_huggingface_datasetio",
+ marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
+ ),
]
@@ -41,6 +51,7 @@ def pytest_configure(config):
for fixture_name in [
"meta_reference_eval_fireworks_inference",
"meta_reference_eval_together_inference",
+ "meta_reference_eval_together_inference_huggingface_datasetio",
]:
config.addinivalue_line(
"markers",
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index a55a754c5..fdd4dcfbb 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -7,10 +7,15 @@
import pytest
-from llama_models.llama3.api import SamplingParams
+from llama_models.llama3.api import SamplingParams, URL
+
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
+
+from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
from llama_stack.apis.eval.eval import (
AppEvalTaskConfig,
+ BenchmarkEvalTaskConfig,
EvalTaskDefWithProvider,
ModelCandidate,
)
@@ -21,7 +26,7 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase
# How to run this test:
#
# pytest llama_stack/providers/tests/eval/test_eval.py
-# -m "meta_reference"
+# -m "meta_reference_eval_together_inference_huggingface_datasetio"
# -v -s --tb=short --disable-warnings
@@ -33,21 +38,26 @@ class Testeval:
eval_tasks_impl = eval_stack[Api.eval_tasks]
response = await eval_tasks_impl.list_eval_tasks()
assert isinstance(response, list)
- assert len(response) == 0
@pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack):
- eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
+ eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasetio],
eval_stack[Api.datasets],
+ eval_stack[Api.models],
)
+ for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+ await models_impl.register_model(
+ model_id=model_id,
+ provider_id="",
+ )
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
response = await datasets_impl.list_datasets()
- assert len(response) == 1
+
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset_for_eval",
rows_in_page=3,
@@ -66,7 +76,6 @@ class Testeval:
provider_id="meta-reference",
)
await eval_tasks_impl.register_eval_task(task_def)
-
response = await eval_impl.evaluate_rows(
task_id=task_id,
input_rows=rows.rows,
@@ -84,11 +93,17 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack):
- eval_impl, eval_tasks_impl, datasets_impl = (
+ eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasets],
+ eval_stack[Api.models],
)
+ for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+ await models_impl.register_model(
+ model_id=model_id,
+ provider_id="",
+ )
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
@@ -124,3 +139,72 @@ class Testeval:
assert len(eval_response.generations) == 5
assert "meta-reference::subset_of" in eval_response.scores
assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores
+
+ @pytest.mark.asyncio
+ async def test_eval_run_benchmark_eval(self, eval_stack):
+ eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
+ eval_stack[Api.eval],
+ eval_stack[Api.eval_tasks],
+ eval_stack[Api.datasets],
+ eval_stack[Api.models],
+ )
+ for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+ await models_impl.register_model(
+ model_id=model_id,
+ provider_id="",
+ )
+ response = await datasets_impl.list_datasets()
+ assert len(response) > 0
+ if response[0].provider_id != "huggingface":
+ pytest.skip(
+ "Only huggingface provider supports pre-registered remote datasets"
+ )
+ # register dataset
+ mmlu = DatasetDefWithProvider(
+ identifier="mmlu",
+ url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
+ dataset_schema={
+ "input_query": StringType(),
+ "expected_answer": StringType(),
+ "chat_completion_input": ChatCompletionInputType(),
+ },
+ metadata={
+ "path": "llamastack/evals",
+ "name": "evals__mmlu__details",
+ "split": "train",
+ },
+ provider_id="",
+ )
+
+ await datasets_impl.register_dataset(mmlu)
+
+ # register eval task
+ meta_reference_mmlu = EvalTaskDefWithProvider(
+ identifier="meta-reference-mmlu",
+ dataset_id="mmlu",
+ scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
+ provider_id="",
+ )
+
+ await eval_tasks_impl.register_eval_task(meta_reference_mmlu)
+
+ # list benchmarks
+ response = await eval_tasks_impl.list_eval_tasks()
+ assert len(response) > 0
+
+ benchmark_id = "meta-reference-mmlu"
+ response = await eval_impl.run_eval(
+ task_id=benchmark_id,
+ task_config=BenchmarkEvalTaskConfig(
+ eval_candidate=ModelCandidate(
+ model="Llama3.2-3B-Instruct",
+ sampling_params=SamplingParams(),
+ ),
+ num_examples=3,
+ ),
+ )
+ job_status = await eval_impl.job_status(benchmark_id, response.job_id)
+ assert job_status and job_status.value == "completed"
+ eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
+ assert eval_response is not None
+ assert len(eval_response.generations) == 3
diff --git a/llama_stack/providers/utils/datasetio/__init__.py b/llama_stack/providers/utils/datasetio/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/utils/datasetio/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py
new file mode 100644
index 000000000..3faea9f95
--- /dev/null
+++ b/llama_stack/providers/utils/datasetio/url_utils.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import io
+from urllib.parse import unquote
+
+import pandas
+
+from llama_models.llama3.api.datatypes import URL
+
+from llama_stack.providers.utils.memory.vector_store import parse_data_url
+
+
+def get_dataframe_from_url(url: URL):
+ df = None
+ if url.uri.endswith(".csv"):
+ df = pandas.read_csv(url.uri)
+ elif url.uri.endswith(".xlsx"):
+ df = pandas.read_excel(url.uri)
+ elif url.uri.startswith("data:"):
+ parts = parse_data_url(url.uri)
+ data = parts["data"]
+ if parts["is_base64"]:
+ data = base64.b64decode(data)
+ else:
+ data = unquote(data)
+ encoding = parts["encoding"] or "utf-8"
+ data = data.encode(encoding)
+
+ mime_type = parts["mimetype"]
+ mime_category = mime_type.split("/")[0]
+ data_bytes = io.BytesIO(data)
+
+ if mime_category == "text":
+ df = pandas.read_csv(data_bytes)
+ else:
+ df = pandas.read_excel(data_bytes)
+ else:
+ raise ValueError(f"Unsupported file type: {url}")
+
+ return df