From 2b7d70ba86bf33d55fd6fc67baec3b7ec13e66f8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 11 Nov 2024 14:49:50 -0500 Subject: [PATCH] [Evals API][11/n] huggingface dataset provider + mmlu scoring fn (#392) * wip * scoring fn api * eval api * eval task * evaluate api update * pre commit * unwrap context -> config * config field doc * typo * naming fix * separate benchmark / app eval * api name * rename * wip tests * wip * datasetio test * delete unused * fixture * scoring resolve * fix scoring register * scoring test pass * score batch * scoring fix * fix eval * test eval works * huggingface provider * datasetdef files * mmlu scoring fn * test wip * remove type ignore * api refactor * add default task_eval_id for routing * add eval_id for jobs * remove type ignore * huggingface provider * wip huggingface register * only keep 1 run_eval * fix optional * register task required * register task required * delete old tests * fix * mmlu loose * refactor * msg * fix tests * move benchmark task def to file * msg * gen openapi * openapi gen * move dataset to hf llamastack repo * remove todo * refactor * add register model to unit test * rename * register to client * delete preregistered dataset/eval task * comments * huggingface -> remote adapter * openapi gen --- docs/openapi_generator/generate.py | 2 + docs/resources/llama-stack-spec.html | 1069 +++++++++++------ docs/resources/llama-stack-spec.yaml | 754 +++++++----- llama_stack/apis/eval/eval.py | 8 + .../datasetio/huggingface/__init__.py | 18 + .../adapters/datasetio/huggingface/config.py | 9 + .../datasetio/huggingface/huggingface.py | 81 ++ .../meta_reference/datasetio/datasetio.py | 33 +- .../inline/meta_reference/eval/eval.py | 11 +- .../inline/meta_reference/scoring/scoring.py | 17 +- .../scoring/scoring_fn/fn_defs/equality.py | 1 - .../fn_defs/llm_as_judge_8b_correctness.py | 1 - .../regex_parser_multiple_choice_answer.py | 69 ++ .../scoring_fn/regex_parser_scoring_fn.py | 67 ++ 
llama_stack/providers/registry/datasetio.py | 11 + .../providers/tests/datasetio/fixtures.py | 15 +- llama_stack/providers/tests/eval/conftest.py | 11 + llama_stack/providers/tests/eval/test_eval.py | 98 +- .../providers/utils/datasetio/__init__.py | 5 + .../providers/utils/datasetio/url_utils.py | 45 + 20 files changed, 1607 insertions(+), 718 deletions(-) create mode 100644 llama_stack/providers/adapters/datasetio/huggingface/__init__.py create mode 100644 llama_stack/providers/adapters/datasetio/huggingface/config.py create mode 100644 llama_stack/providers/adapters/datasetio/huggingface/huggingface.py create mode 100644 llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py create mode 100644 llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py create mode 100644 llama_stack/providers/utils/datasetio/__init__.py create mode 100644 llama_stack/providers/utils/datasetio/url_utils.py diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index f9f56119b..dbfc90452 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -49,6 +49,7 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.inspect import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 class LlamaStack( @@ -63,6 +64,7 @@ class LlamaStack( PostTraining, Memory, Eval, + EvalTasks, Scoring, ScoringFunctions, DatasetIO, diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 363d968f9..8156039a9 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set 
of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-11 13:59:59.544511" }, "servers": [ { @@ -469,7 +469,7 @@ } } }, - "/eval/evaluate": { + "/eval/evaluate_rows": { "post": { "responses": { "200": { @@ -501,47 +501,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateRequest" - } - } - }, - "required": true - } - } - }, - "/eval/evaluate_batch": { - "post": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Job" - } - } - } - } - }, - "tags": [ - "Eval" - ], - "parameters": [ - { - "name": "X-LlamaStack-ProviderData", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/EvaluateBatchRequest" + "$ref": "#/components/schemas/EvaluateRowsRequest" } } }, @@ -766,6 +726,51 @@ ] } }, + "/eval_tasks/get": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/EvalTaskDefWithProvider" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "EvalTasks" + ], + "parameters": [ + { + "name": "name", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data 
which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/memory_banks/get": { "get": { "responses": { @@ -834,7 +839,7 @@ "schema": { "oneOf": [ { - "$ref": "#/components/schemas/ModelDefWithProvider" + "$ref": "#/components/schemas/Model" }, { "type": "null" @@ -986,7 +991,7 @@ "schema": { "oneOf": [ { - "$ref": "#/components/schemas/ShieldDefWithProvider" + "$ref": "#/components/schemas/Shield" }, { "type": "null" @@ -1002,7 +1007,7 @@ ], "parameters": [ { - "name": "shield_type", + "name": "identifier", "in": "query", "required": true, "schema": { @@ -1317,6 +1322,14 @@ "Eval" ], "parameters": [ + { + "name": "task_id", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, { "name": "job_id", "in": "query", @@ -1362,6 +1375,14 @@ "Eval" ], "parameters": [ + { + "name": "task_id", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, { "name": "job_id", "in": "query", @@ -1412,6 +1433,36 @@ ] } }, + "/eval_tasks/list": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/jsonl": { + "schema": { + "$ref": "#/components/schemas/EvalTaskDefWithProvider" + } + } + } + } + }, + "tags": [ + "EvalTasks" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/memory_banks/list": { "get": { "responses": { @@ -1463,7 +1514,7 @@ "content": { "application/jsonl": { "schema": { - "$ref": "#/components/schemas/ModelDefWithProvider" + "$ref": "#/components/schemas/Model" } } } @@ -1592,7 +1643,7 @@ "content": { "application/jsonl": { "schema": { - "$ref": "#/components/schemas/ShieldDefWithProvider" + "$ref": "#/components/schemas/Shield" } } } @@ -1760,6 +1811,39 @@ } } }, + 
"/eval_tasks/register": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "EvalTasks" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RegisterEvalTaskRequest" + } + } + }, + "required": true + } + } + }, "/memory_banks/register": { "post": { "responses": { @@ -1797,7 +1881,14 @@ "post": { "responses": { "200": { - "description": "OK" + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Model" + } + } + } } }, "tags": [ @@ -1863,7 +1954,14 @@ "post": { "responses": { "200": { - "description": "OK" + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Shield" + } + } + } } }, "tags": [ @@ -1892,6 +1990,46 @@ } } }, + "/eval/run_eval": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Job" + } + } + } + } + }, + "tags": [ + "Eval" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunEvalRequest" + } + } + }, + "required": true + } + } + }, "/safety/run_shield": { "post": { "responses": { @@ -4490,6 +4628,103 @@ "config" ] }, + "AppEvalTaskConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "app", + "default": "app" + }, + "eval_candidate": { + "oneOf": [ + { + "$ref": 
"#/components/schemas/ModelCandidate" + }, + { + "$ref": "#/components/schemas/AgentCandidate" + } + ] + }, + "scoring_params": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" + }, + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" + } + ] + } + }, + "num_examples": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "eval_candidate", + "scoring_params" + ] + }, + "BenchmarkEvalTaskConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "benchmark", + "default": "benchmark" + }, + "eval_candidate": { + "oneOf": [ + { + "$ref": "#/components/schemas/ModelCandidate" + }, + { + "$ref": "#/components/schemas/AgentCandidate" + } + ] + }, + "num_examples": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "eval_candidate" + ] + }, + "LLMAsJudgeScoringFnParams": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm_as_judge", + "default": "llm_as_judge" + }, + "judge_model": { + "type": "string" + }, + "prompt_template": { + "type": "string" + }, + "judge_score_regexes": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "judge_model" + ] + }, "ModelCandidate": { "type": "object", "properties": { @@ -4515,9 +4750,32 @@ "sampling_params" ] }, - "EvaluateRequest": { + "RegexParserScoringFnParams": { "type": "object", "properties": { + "type": { + "type": "string", + "const": "regex_parser", + "default": "regex_parser" + }, + "parsing_regexes": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + }, + "EvaluateRowsRequest": { + "type": "object", + "properties": { + "task_id": { + "type": "string" + }, "input_rows": { "type": "array", "items": { @@ -4546,28 +4804,29 @@ } } }, - 
"candidate": { - "oneOf": [ - { - "$ref": "#/components/schemas/ModelCandidate" - }, - { - "$ref": "#/components/schemas/AgentCandidate" - } - ] - }, "scoring_functions": { "type": "array", "items": { "type": "string" } + }, + "task_config": { + "oneOf": [ + { + "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" + }, + { + "$ref": "#/components/schemas/AppEvalTaskConfig" + } + ] } }, "additionalProperties": false, "required": [ + "task_id", "input_rows", - "candidate", - "scoring_functions" + "scoring_functions", + "task_config" ] }, "EvaluateResponse": { @@ -4677,48 +4936,6 @@ "aggregated_results" ] }, - "EvaluateBatchRequest": { - "type": "object", - "properties": { - "dataset_id": { - "type": "string" - }, - "candidate": { - "oneOf": [ - { - "$ref": "#/components/schemas/ModelCandidate" - }, - { - "$ref": "#/components/schemas/AgentCandidate" - } - ] - }, - "scoring_functions": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "required": [ - "dataset_id", - "candidate", - "scoring_functions" - ] - }, - "Job": { - "type": "object", - "properties": { - "job_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_id" - ] - }, "GetAgentsSessionRequest": { "type": "object", "properties": { @@ -5085,6 +5302,11 @@ ] } }, + "type": { + "type": "string", + "const": "dataset", + "default": "dataset" + }, "provider_id": { "type": "string" } @@ -5095,18 +5317,25 @@ "dataset_schema", "url", "metadata", + "type", "provider_id" ] }, - "ModelDefWithProvider": { + "EvalTaskDefWithProvider": { "type": "object", "properties": { "identifier": { "type": "string" }, - "llama_model": { + "dataset_id": { "type": "string" }, + "scoring_functions": { + "type": "array", + "items": { + "type": "string" + } + }, "metadata": { "type": "object", "additionalProperties": { @@ -5132,6 +5361,11 @@ ] } }, + "type": { + "type": "string", + "const": "eval_task", + "default": "eval_task" + }, "provider_id": { 
"type": "string" } @@ -5139,11 +5373,65 @@ "additionalProperties": false, "required": [ "identifier", - "llama_model", + "dataset_id", + "scoring_functions", "metadata", + "type", "provider_id" ] }, + "Model": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "model", + "default": "model" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "provider_id", + "type", + "metadata" + ] + }, "PaginatedRowsResult": { "type": "object", "properties": { @@ -5188,166 +5476,6 @@ "total_count" ] }, - "Parameter": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "type": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "string", - "default": "string" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "number", - "default": "number" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "boolean", - "default": "boolean" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "array", - "default": "array" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "object", - "default": "object" - } - }, - 
"additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json", - "default": "json" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "union", - "default": "union" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "chat_completion_input", - "default": "chat_completion_input" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "completion_input", - "default": "completion_input" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "agent_turn_input", - "default": "agent_turn_input" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - } - ] - }, - "description": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "name", - "type" - ] - }, "ScoringFnDefWithProvider": { "type": "object", "properties": { @@ -5382,12 +5510,6 @@ ] } }, - "parameters": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Parameter" - } - }, "return_type": { "oneOf": [ { @@ -5532,27 +5654,21 @@ } ] }, - "context": { - "type": "object", - "properties": { - "judge_model": { - "type": "string" + "params": { + "oneOf": [ + { + "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" }, - "prompt_template": { - "type": "string" - }, - "judge_score_regex": { - "type": "array", - "items": { - "type": "string" - } + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" } - }, - "additionalProperties": false, - "required": [ - "judge_model" ] }, + "type": { + "type": "string", + "const": "scoring_fn", + 
"default": "scoring_fn" + }, "provider_id": { "type": "string" } @@ -5561,20 +5677,31 @@ "required": [ "identifier", "metadata", - "parameters", "return_type", + "type", "provider_id" ] }, - "ShieldDefWithProvider": { + "Shield": { "type": "object", "properties": { "identifier": { "type": "string" }, - "type": { + "provider_resource_id": { "type": "string" }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "shield", + "default": "shield" + }, + "shield_type": { + "$ref": "#/components/schemas/ShieldType" + }, "params": { "type": "object", "additionalProperties": { @@ -5599,17 +5726,26 @@ } ] } - }, - "provider_id": { - "type": "string" } }, "additionalProperties": false, "required": [ "identifier", + "provider_resource_id", + "provider_id", "type", - "params", - "provider_id" + "shield_type", + "params" + ], + "title": "A safety shield resource that can be used to check content" + }, + "ShieldType": { + "type": "string", + "enum": [ + "generic_content_shield", + "llama_guard", + "code_scanner", + "prompt_guard" ] }, "Trace": { @@ -5867,12 +6003,16 @@ "JobCancelRequest": { "type": "object", "properties": { + "task_id": { + "type": "string" + }, "job_id": { "type": "string" } }, "additionalProperties": false, "required": [ + "task_id", "job_id" ] }, @@ -6514,6 +6654,18 @@ "dataset_def" ] }, + "RegisterEvalTaskRequest": { + "type": "object", + "properties": { + "eval_task_def": { + "$ref": "#/components/schemas/EvalTaskDefWithProvider" + } + }, + "additionalProperties": false, + "required": [ + "eval_task_def" + ] + }, "RegisterMemoryBankRequest": { "type": "object", "properties": { @@ -6542,13 +6694,44 @@ "RegisterModelRequest": { "type": "object", "properties": { - "model": { - "$ref": "#/components/schemas/ModelDefWithProvider" + "model_id": { + "type": "string" + }, + "provider_model_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + 
"oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, "required": [ - "model" + "model_id" ] }, "RegisterScoringFunctionRequest": { @@ -6566,19 +6749,89 @@ "RegisterShieldRequest": { "type": "object", "properties": { - "shield": { - "$ref": "#/components/schemas/ShieldDefWithProvider" + "shield_id": { + "type": "string" + }, + "shield_type": { + "$ref": "#/components/schemas/ShieldType" + }, + "provider_shield_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "params": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, "required": [ - "shield" + "shield_id", + "shield_type" + ] + }, + "RunEvalRequest": { + "type": "object", + "properties": { + "task_id": { + "type": "string" + }, + "task_config": { + "oneOf": [ + { + "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" + }, + { + "$ref": "#/components/schemas/AppEvalTaskConfig" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "task_id", + "task_config" + ] + }, + "Job": { + "type": "object", + "properties": { + "job_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "job_id" ] }, "RunShieldRequest": { "type": "object", "properties": { - "shield_type": { + "shield_id": { "type": "string" }, "messages": { @@ -6628,7 +6881,7 @@ }, "additionalProperties": false, "required": [ - "shield_type", + "shield_id", "messages", "params" ] @@ -6674,9 +6927,23 @@ } }, "scoring_functions": { - "type": "array", - "items": { - "type": "string" + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "oneOf": [ + { + "$ref": 
"#/components/schemas/LLMAsJudgeScoringFnParams" + }, + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" + } + ] + }, + { + "type": "null" + } + ] } } }, @@ -6708,9 +6975,23 @@ "type": "string" }, "scoring_functions": { - "type": "array", - "items": { - "type": "string" + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "oneOf": [ + { + "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" + }, + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" + } + ] + }, + { + "type": "null" + } + ] } }, "save_results_dataset": { @@ -7063,56 +7344,59 @@ ], "tags": [ { - "name": "Memory" - }, - { - "name": "Inference" - }, - { - "name": "Eval" - }, - { - "name": "MemoryBanks" - }, - { - "name": "Models" - }, - { - "name": "BatchInference" - }, - { - "name": "PostTraining" - }, - { - "name": "Agents" - }, - { - "name": "Shields" - }, - { - "name": "Telemetry" - }, - { - "name": "Inspect" - }, - { - "name": "DatasetIO" - }, - { - "name": "SyntheticDataGeneration" + "name": "ScoringFunctions" }, { "name": "Datasets" }, { - "name": "Scoring" - }, - { - "name": "ScoringFunctions" + "name": "Inspect" }, { "name": "Safety" }, + { + "name": "Eval" + }, + { + "name": "Inference" + }, + { + "name": "BatchInference" + }, + { + "name": "Agents" + }, + { + "name": "PostTraining" + }, + { + "name": "Shields" + }, + { + "name": "Memory" + }, + { + "name": "Scoring" + }, + { + "name": "SyntheticDataGeneration" + }, + { + "name": "EvalTasks" + }, + { + "name": "MemoryBanks" + }, + { + "name": "DatasetIO" + }, + { + "name": "Models" + }, + { + "name": "Telemetry" + }, { "name": "BuiltinTool", "description": "" @@ -7377,13 +7661,29 @@ "name": "AgentCandidate", "description": "" }, + { + "name": "AppEvalTaskConfig", + "description": "" + }, + { + "name": "BenchmarkEvalTaskConfig", + "description": "" + }, + { + "name": "LLMAsJudgeScoringFnParams", + "description": "" + }, { "name": "ModelCandidate", "description": "" }, { - "name": "EvaluateRequest", - 
"description": "" + "name": "RegexParserScoringFnParams", + "description": "" + }, + { + "name": "EvaluateRowsRequest", + "description": "" }, { "name": "EvaluateResponse", @@ -7393,14 +7693,6 @@ "name": "ScoringResult", "description": "" }, - { - "name": "EvaluateBatchRequest", - "description": "" - }, - { - "name": "Job", - "description": "" - }, { "name": "GetAgentsSessionRequest", "description": "" @@ -7434,24 +7726,28 @@ "description": "" }, { - "name": "ModelDefWithProvider", - "description": "" + "name": "EvalTaskDefWithProvider", + "description": "" + }, + { + "name": "Model", + "description": "" }, { "name": "PaginatedRowsResult", "description": "" }, - { - "name": "Parameter", - "description": "" - }, { "name": "ScoringFnDefWithProvider", "description": "" }, { - "name": "ShieldDefWithProvider", - "description": "" + "name": "Shield", + "description": "A safety shield resource that can be used to check content\n\n" + }, + { + "name": "ShieldType", + "description": "" }, { "name": "Trace", @@ -7573,6 +7869,10 @@ "name": "RegisterDatasetRequest", "description": "" }, + { + "name": "RegisterEvalTaskRequest", + "description": "" + }, { "name": "RegisterMemoryBankRequest", "description": "" @@ -7589,6 +7889,14 @@ "name": "RegisterShieldRequest", "description": "" }, + { + "name": "RunEvalRequest", + "description": "" + }, + { + "name": "Job", + "description": "" + }, { "name": "RunShieldRequest", "description": "" @@ -7651,6 +7959,7 @@ "DatasetIO", "Datasets", "Eval", + "EvalTasks", "Inference", "Inspect", "Memory", @@ -7680,11 +7989,13 @@ "AgentTurnResponseStreamChunk", "AgentTurnResponseTurnCompletePayload", "AgentTurnResponseTurnStartPayload", + "AppEvalTaskConfig", "Attachment", "BatchChatCompletionRequest", "BatchChatCompletionResponse", "BatchCompletionRequest", "BatchCompletionResponse", + "BenchmarkEvalTaskConfig", "BuiltinTool", "CancelTrainingJobRequest", "ChatCompletionRequest", @@ -7708,9 +8019,9 @@ "DoraFinetuningConfig", "EmbeddingsRequest", 
"EmbeddingsResponse", - "EvaluateBatchRequest", - "EvaluateRequest", + "EvalTaskDefWithProvider", "EvaluateResponse", + "EvaluateRowsRequest", "FinetuningAlgorithm", "FunctionCallToolDefinition", "GetAgentsSessionRequest", @@ -7724,6 +8035,7 @@ "JobStatus", "KeyValueMemoryBankDef", "KeywordMemoryBankDef", + "LLMAsJudgeScoringFnParams", "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", @@ -7731,11 +8043,10 @@ "MemoryRetrievalStep", "MemoryToolDefinition", "MetricEvent", + "Model", "ModelCandidate", - "ModelDefWithProvider", "OptimizerConfig", "PaginatedRowsResult", - "Parameter", "PhotogenToolDefinition", "PostTrainingJob", "PostTrainingJobArtifactsResponse", @@ -7748,7 +8059,9 @@ "QueryDocumentsRequest", "QueryDocumentsResponse", "RLHFAlgorithm", + "RegexParserScoringFnParams", "RegisterDatasetRequest", + "RegisterEvalTaskRequest", "RegisterMemoryBankRequest", "RegisterModelRequest", "RegisterScoringFunctionRequest", @@ -7756,6 +8069,7 @@ "RestAPIExecutionConfig", "RestAPIMethod", "RouteInfo", + "RunEvalRequest", "RunShieldRequest", "RunShieldResponse", "SafetyViolation", @@ -7769,8 +8083,9 @@ "ScoringResult", "SearchToolDefinition", "Session", + "Shield", "ShieldCallStep", - "ShieldDefWithProvider", + "ShieldType", "SpanEndPayload", "SpanStartPayload", "SpanStatus", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 7dd231965..0e6571301 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -218,6 +218,30 @@ components: - event_type - turn_id type: object + AppEvalTaskConfig: + additionalProperties: false + properties: + eval_candidate: + oneOf: + - $ref: '#/components/schemas/ModelCandidate' + - $ref: '#/components/schemas/AgentCandidate' + num_examples: + type: integer + scoring_params: + additionalProperties: + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + type: object + type: + const: app 
+ default: app + type: string + required: + - type + - eval_candidate + - scoring_params + type: object Attachment: additionalProperties: false properties: @@ -322,6 +346,23 @@ components: required: - completion_message_batch type: object + BenchmarkEvalTaskConfig: + additionalProperties: false + properties: + eval_candidate: + oneOf: + - $ref: '#/components/schemas/ModelCandidate' + - $ref: '#/components/schemas/AgentCandidate' + num_examples: + type: integer + type: + const: benchmark + default: benchmark + type: string + required: + - type + - eval_candidate + type: object BuiltinTool: enum: - brave_search @@ -790,6 +831,10 @@ components: type: object provider_id: type: string + type: + const: dataset + default: dataset + type: string url: $ref: '#/components/schemas/URL' required: @@ -797,6 +842,7 @@ components: - dataset_schema - url - metadata + - type - provider_id type: object DeleteAgentsRequest: @@ -872,51 +918,40 @@ components: required: - embeddings type: object - EvaluateBatchRequest: + EvalTaskDefWithProvider: additionalProperties: false properties: - candidate: - oneOf: - - $ref: '#/components/schemas/ModelCandidate' - - $ref: '#/components/schemas/AgentCandidate' dataset_id: type: string + identifier: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + provider_id: + type: string scoring_functions: items: type: string type: array + type: + const: eval_task + default: eval_task + type: string required: + - identifier - dataset_id - - candidate - - scoring_functions - type: object - EvaluateRequest: - additionalProperties: false - properties: - candidate: - oneOf: - - $ref: '#/components/schemas/ModelCandidate' - - $ref: '#/components/schemas/AgentCandidate' - input_rows: - items: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: 
object - type: array - scoring_functions: - items: - type: string - type: array - required: - - input_rows - - candidate - scoring_functions + - metadata + - type + - provider_id type: object EvaluateResponse: additionalProperties: false @@ -941,6 +976,37 @@ components: - generations - scores type: object + EvaluateRowsRequest: + additionalProperties: false + properties: + input_rows: + items: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: array + scoring_functions: + items: + type: string + type: array + task_config: + oneOf: + - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' + - $ref: '#/components/schemas/AppEvalTaskConfig' + task_id: + type: string + required: + - task_id + - input_rows + - scoring_functions + - task_config + type: object FinetuningAlgorithm: enum: - full @@ -1082,7 +1148,10 @@ components: properties: job_id: type: string + task_id: + type: string required: + - task_id - job_id type: object JobStatus: @@ -1124,6 +1193,25 @@ components: - provider_id - type type: object + LLMAsJudgeScoringFnParams: + additionalProperties: false + properties: + judge_model: + type: string + judge_score_regexes: + items: + type: string + type: array + prompt_template: + type: string + type: + const: llm_as_judge + default: llm_as_judge + type: string + required: + - type + - judge_model + type: object LogEventRequest: additionalProperties: false properties: @@ -1405,6 +1493,36 @@ components: - value - unit type: object + Model: + additionalProperties: false + properties: + identifier: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + provider_id: + type: string + provider_resource_id: + type: string + type: + const: model + default: model + type: string + required: + - identifier + - provider_resource_id + - provider_id + 
- type + - metadata + type: object ModelCandidate: additionalProperties: false properties: @@ -1423,31 +1541,6 @@ components: - model - sampling_params type: object - ModelDefWithProvider: - additionalProperties: false - properties: - identifier: - type: string - llama_model: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - provider_id: - type: string - required: - - identifier - - llama_model - - metadata - - provider_id - type: object OptimizerConfig: additionalProperties: false properties: @@ -1492,109 +1585,6 @@ components: - rows - total_count type: object - Parameter: - additionalProperties: false - properties: - description: - type: string - name: - type: string - type: - oneOf: - - additionalProperties: false - properties: - type: - const: string - default: string - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: number - default: number - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: boolean - default: boolean - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: array - default: array - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: object - default: object - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: json - default: json - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: union - default: union - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: chat_completion_input - default: chat_completion_input - type: string - required: - - type - type: object - - additionalProperties: false - 
properties: - type: - const: completion_input - default: completion_input - type: string - required: - - type - type: object - - additionalProperties: false - properties: - type: - const: agent_turn_input - default: agent_turn_input - type: string - required: - - type - type: object - required: - - name - - type - type: object PhotogenToolDefinition: additionalProperties: false properties: @@ -1844,6 +1834,20 @@ components: enum: - dpo type: string + RegexParserScoringFnParams: + additionalProperties: false + properties: + parsing_regexes: + items: + type: string + type: array + type: + const: regex_parser + default: regex_parser + type: string + required: + - type + type: object RegisterDatasetRequest: additionalProperties: false properties: @@ -1852,6 +1856,14 @@ components: required: - dataset_def type: object + RegisterEvalTaskRequest: + additionalProperties: false + properties: + eval_task_def: + $ref: '#/components/schemas/EvalTaskDefWithProvider' + required: + - eval_task_def + type: object RegisterMemoryBankRequest: additionalProperties: false properties: @@ -1867,10 +1879,24 @@ components: RegisterModelRequest: additionalProperties: false properties: - model: - $ref: '#/components/schemas/ModelDefWithProvider' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model_id: + type: string + provider_id: + type: string + provider_model_id: + type: string required: - - model + - model_id type: object RegisterScoringFunctionRequest: additionalProperties: false @@ -1883,10 +1909,27 @@ components: RegisterShieldRequest: additionalProperties: false properties: - shield: - $ref: '#/components/schemas/ShieldDefWithProvider' + params: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + provider_id: + type: string + provider_shield_id: + type: string + 
shield_id: + type: string + shield_type: + $ref: '#/components/schemas/ShieldType' required: - - shield + - shield_id + - shield_type type: object RestAPIExecutionConfig: additionalProperties: false @@ -1952,6 +1995,19 @@ components: - method - provider_types type: object + RunEvalRequest: + additionalProperties: false + properties: + task_config: + oneOf: + - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' + - $ref: '#/components/schemas/AppEvalTaskConfig' + task_id: + type: string + required: + - task_id + - task_config + type: object RunShieldRequest: additionalProperties: false properties: @@ -1973,10 +2029,10 @@ components: - type: array - type: object type: object - shield_type: + shield_id: type: string required: - - shield_type + - shield_id - messages - params type: object @@ -2045,9 +2101,13 @@ components: save_results_dataset: type: boolean scoring_functions: - items: - type: string - type: array + additionalProperties: + oneOf: + - oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - type: 'null' + type: object required: - dataset_id - scoring_functions @@ -2081,9 +2141,13 @@ components: type: object type: array scoring_functions: - items: - type: string - type: array + additionalProperties: + oneOf: + - oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - type: 'null' + type: object required: - input_rows - scoring_functions @@ -2101,20 +2165,6 @@ components: ScoringFnDefWithProvider: additionalProperties: false properties: - context: - additionalProperties: false - properties: - judge_model: - type: string - judge_score_regex: - items: - type: string - type: array - prompt_template: - type: string - required: - - judge_model - type: object description: type: string identifier: @@ -2129,10 +2179,10 @@ components: - type: array - type: object type: object - parameters: - items: - $ref: 
'#/components/schemas/Parameter' - type: array + params: + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' provider_id: type: string return_type: @@ -2227,11 +2277,15 @@ components: required: - type type: object + type: + const: scoring_fn + default: scoring_fn + type: string required: - identifier - metadata - - parameters - return_type + - type - provider_id type: object ScoringResult: @@ -2320,6 +2374,40 @@ components: - started_at title: A single session of an interaction with an Agentic System. type: object + Shield: + additionalProperties: false + properties: + identifier: + type: string + params: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + provider_id: + type: string + provider_resource_id: + type: string + shield_type: + $ref: '#/components/schemas/ShieldType' + type: + const: shield + default: shield + type: string + required: + - identifier + - provider_resource_id + - provider_id + - type + - shield_type + - params + title: A safety shield resource that can be used to check content + type: object ShieldCallStep: additionalProperties: false properties: @@ -2344,31 +2432,13 @@ components: - step_id - step_type type: object - ShieldDefWithProvider: - additionalProperties: false - properties: - identifier: - type: string - params: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - provider_id: - type: string - type: - type: string - required: - - identifier - - type - - params - - provider_id - type: object + ShieldType: + enum: + - generic_content_shield + - llama_guard + - code_scanner + - prompt_guard + type: string SpanEndPayload: additionalProperties: false properties: @@ -2998,7 +3068,7 @@ info: description: "This is the specification of the llama stack that provides\n \ 
\ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905" + \ draft and subject to change.\n Generated at 2024-11-11 13:59:59.544511" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -3387,7 +3457,7 @@ paths: description: OK tags: - Datasets - /eval/evaluate: + /eval/evaluate_rows: post: parameters: - description: JSON-encoded provider data which will be made available to the @@ -3401,7 +3471,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/EvaluateRequest' + $ref: '#/components/schemas/EvaluateRowsRequest' required: true responses: '200': @@ -3412,31 +3482,6 @@ paths: description: OK tags: - Eval - /eval/evaluate_batch: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/EvaluateBatchRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/Job' - description: OK - tags: - - Eval /eval/job/cancel: post: parameters: @@ -3461,6 +3506,11 @@ paths: /eval/job/result: get: parameters: + - in: query + name: task_id + required: true + schema: + type: string - in: query name: job_id required: true @@ -3485,6 +3535,11 @@ paths: /eval/job/status: get: parameters: + - in: query + name: task_id + required: true + schema: + type: string - in: query name: job_id required: true @@ -3508,6 +3563,97 @@ paths: description: OK tags: - Eval + /eval/run_eval: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: 
X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RunEvalRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Job' + description: OK + tags: + - Eval + /eval_tasks/get: + get: + parameters: + - in: query + name: name + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/EvalTaskDefWithProvider' + - type: 'null' + description: OK + tags: + - EvalTasks + /eval_tasks/list: + get: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/EvalTaskDefWithProvider' + description: OK + tags: + - EvalTasks + /eval_tasks/register: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterEvalTaskRequest' + required: true + responses: + '200': + description: OK + tags: + - EvalTasks /health: get: parameters: @@ -3747,7 +3893,7 @@ paths: application/json: schema: oneOf: - - $ref: '#/components/schemas/ModelDefWithProvider' + - $ref: '#/components/schemas/Model' - type: 'null' description: OK tags: @@ -3767,7 +3913,7 @@ paths: content: application/jsonl: schema: - $ref: '#/components/schemas/ModelDefWithProvider' + $ref: 
'#/components/schemas/Model' description: OK tags: - Models @@ -3789,6 +3935,10 @@ paths: required: true responses: '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Model' description: OK tags: - Models @@ -4143,7 +4293,7 @@ paths: get: parameters: - in: query - name: shield_type + name: identifier required: true schema: type: string @@ -4160,7 +4310,7 @@ paths: application/json: schema: oneOf: - - $ref: '#/components/schemas/ShieldDefWithProvider' + - $ref: '#/components/schemas/Shield' - type: 'null' description: OK tags: @@ -4180,7 +4330,7 @@ paths: content: application/jsonl: schema: - $ref: '#/components/schemas/ShieldDefWithProvider' + $ref: '#/components/schemas/Shield' description: OK tags: - Shields @@ -4202,6 +4352,10 @@ paths: required: true responses: '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Shield' description: OK tags: - Shields @@ -4280,23 +4434,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Memory -- name: Inference -- name: Eval -- name: MemoryBanks -- name: Models -- name: BatchInference -- name: PostTraining -- name: Agents -- name: Shields -- name: Telemetry -- name: Inspect -- name: DatasetIO -- name: SyntheticDataGeneration -- name: Datasets -- name: Scoring - name: ScoringFunctions +- name: Datasets +- name: Inspect - name: Safety +- name: Eval +- name: Inference +- name: BatchInference +- name: Agents +- name: PostTraining +- name: Shields +- name: Memory +- name: Scoring +- name: SyntheticDataGeneration +- name: EvalTasks +- name: MemoryBanks +- name: DatasetIO +- name: Models +- name: Telemetry - description: name: BuiltinTool - description: name: AgentCandidate +- description: + name: AppEvalTaskConfig +- description: + name: BenchmarkEvalTaskConfig +- description: + name: LLMAsJudgeScoringFnParams - description: name: ModelCandidate -- description: - name: EvaluateRequest + name: RegexParserScoringFnParams +- description: + name: 
EvaluateRowsRequest - description: name: EvaluateResponse - description: name: ScoringResult -- description: - name: EvaluateBatchRequest -- description: - name: Job - description: name: GetAgentsSessionRequest @@ -4544,20 +4706,24 @@ tags: - description: name: DatasetDefWithProvider -- description: - name: ModelDefWithProvider + name: EvalTaskDefWithProvider +- description: + name: Model - description: name: PaginatedRowsResult -- description: - name: Parameter - description: name: ScoringFnDefWithProvider -- description: - name: ShieldDefWithProvider +- description: 'A safety shield resource that can be used to check content + + + ' + name: Shield +- description: + name: ShieldType - description: name: Trace - description: 'Checkpoint created during training runs @@ -4647,6 +4813,9 @@ tags: - description: name: RegisterDatasetRequest +- description: + name: RegisterEvalTaskRequest - description: name: RegisterMemoryBankRequest @@ -4659,6 +4828,10 @@ tags: - description: name: RegisterShieldRequest +- description: + name: RunEvalRequest +- description: + name: Job - description: name: RunShieldRequest @@ -4708,6 +4881,7 @@ x-tagGroups: - DatasetIO - Datasets - Eval + - EvalTasks - Inference - Inspect - Memory @@ -4734,11 +4908,13 @@ x-tagGroups: - AgentTurnResponseStreamChunk - AgentTurnResponseTurnCompletePayload - AgentTurnResponseTurnStartPayload + - AppEvalTaskConfig - Attachment - BatchChatCompletionRequest - BatchChatCompletionResponse - BatchCompletionRequest - BatchCompletionResponse + - BenchmarkEvalTaskConfig - BuiltinTool - CancelTrainingJobRequest - ChatCompletionRequest @@ -4762,9 +4938,9 @@ x-tagGroups: - DoraFinetuningConfig - EmbeddingsRequest - EmbeddingsResponse - - EvaluateBatchRequest - - EvaluateRequest + - EvalTaskDefWithProvider - EvaluateResponse + - EvaluateRowsRequest - FinetuningAlgorithm - FunctionCallToolDefinition - GetAgentsSessionRequest @@ -4778,6 +4954,7 @@ x-tagGroups: - JobStatus - KeyValueMemoryBankDef - KeywordMemoryBankDef + 
- LLMAsJudgeScoringFnParams - LogEventRequest - LogSeverity - LoraFinetuningConfig @@ -4785,11 +4962,10 @@ x-tagGroups: - MemoryRetrievalStep - MemoryToolDefinition - MetricEvent + - Model - ModelCandidate - - ModelDefWithProvider - OptimizerConfig - PaginatedRowsResult - - Parameter - PhotogenToolDefinition - PostTrainingJob - PostTrainingJobArtifactsResponse @@ -4802,7 +4978,9 @@ x-tagGroups: - QueryDocumentsRequest - QueryDocumentsResponse - RLHFAlgorithm + - RegexParserScoringFnParams - RegisterDatasetRequest + - RegisterEvalTaskRequest - RegisterMemoryBankRequest - RegisterModelRequest - RegisterScoringFunctionRequest @@ -4810,6 +4988,7 @@ x-tagGroups: - RestAPIExecutionConfig - RestAPIMethod - RouteInfo + - RunEvalRequest - RunShieldRequest - RunShieldResponse - SafetyViolation @@ -4823,8 +5002,9 @@ x-tagGroups: - ScoringResult - SearchToolDefinition - Session + - Shield - ShieldCallStep - - ShieldDefWithProvider + - ShieldType - SpanEndPayload - SpanStartPayload - SpanStatus diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 50fb922fe..04a5a55d5 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -40,6 +40,10 @@ EvalCandidate = Annotated[ class BenchmarkEvalTaskConfig(BaseModel): type: Literal["benchmark"] = "benchmark" eval_candidate: EvalCandidate + num_examples: Optional[int] = Field( + description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated", + default=None, + ) @json_schema_type @@ -50,6 +54,10 @@ class AppEvalTaskConfig(BaseModel): description="Map between scoring function id and parameters for each scoring function you want to run", default_factory=dict, ) + num_examples: Optional[int] = Field( + description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated", + default=None, + ) # we could optinally add any specific dataset config here diff --git 
from .config import HuggingfaceDatasetIOConfig


async def get_adapter_impl(
    config: HuggingfaceDatasetIOConfig,
    _deps,
):
    """Build and initialize the Hugging Face DatasetIO adapter.

    The implementation module is imported inside the function body, so it is
    only loaded when this factory is actually called.

    Args:
        config: Provider configuration passed to the implementation.
        _deps: Accepted for the provider-factory signature; not used here.

    Returns:
        The ``HuggingfaceDatasetIOImpl`` instance, after ``initialize()`` has
        been awaited on it.
    """
    from .huggingface import HuggingfaceDatasetIOImpl

    adapter = HuggingfaceDatasetIOImpl(config)
    await adapter.initialize()
    return adapter
from typing import List, Optional

from llama_stack.apis.datasetio import *  # noqa: F403


import datasets as hf_datasets
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

from .config import HuggingfaceDatasetIOConfig


def load_hf_dataset(dataset_def: DatasetDef):
    """Load the dataset described by ``dataset_def`` as a Hugging Face dataset.

    When ``dataset_def.metadata`` has a truthy ``"path"`` entry, the entire
    metadata mapping is forwarded as keyword arguments to
    ``datasets.load_dataset``. Otherwise the dataset URL is read into a
    DataFrame via ``get_dataframe_from_url`` and wrapped with
    ``Dataset.from_pandas``.

    Raises:
        ValueError: if no DataFrame could be obtained from ``dataset_def.url``.
    """
    if dataset_def.metadata.get("path"):
        return hf_datasets.load_dataset(**dataset_def.metadata)

    df = get_dataframe_from_url(dataset_def.url)
    if df is None:
        raise ValueError(f"Failed to load dataset from {dataset_def.url}")

    return hf_datasets.Dataset.from_pandas(df)


class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    """DatasetIO provider that serves rows of registered datasets through the
    Hugging Face ``datasets`` library."""

    def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset_def: DatasetDef,
    ) -> None:
        """Remember ``dataset_def`` under its identifier (last write wins)."""
        self.dataset_infos[dataset_def.identifier] = dataset_def

    async def list_datasets(self) -> List[DatasetDef]:
        """Return every dataset definition registered with this provider."""
        return list(self.dataset_infos.values())

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Return one page of rows from a registered dataset.

        Args:
            dataset_id: Identifier used when the dataset was registered.
            rows_in_page: Maximum number of rows to return; ``-1`` means
                "everything from the start offset to the end of the dataset".
            page_token: Decimal start offset as a string; ``None`` or ``""``
                starts at row 0.
            filter_condition: Accepted for interface compatibility; this
                implementation does not apply it.

        Returns:
            A ``PaginatedRowsResult`` whose ``total_count`` is the number of
            rows in this page and whose ``next_page_token`` is the offset just
            past the last returned row.

        Raises:
            KeyError: if ``dataset_id`` was never registered.
            ValueError: if ``page_token`` or ``rows_in_page`` is invalid, or
                the dataset cannot be loaded from its URL.
        """
        # Validate the cheap arguments before paying for a dataset load.
        if rows_in_page < -1:
            # A smaller value would produce a negative next_page_token that the
            # next call rejects.
            raise ValueError(f"Invalid rows_in_page: {rows_in_page}")
        # isdecimal() only accepts digits that int() can parse; isnumeric()
        # also lets through characters such as "½" or "²".
        if page_token and not page_token.isdecimal():
            raise ValueError("Invalid page_token")

        dataset_def = self.dataset_infos[dataset_id]
        loaded_dataset = load_hf_dataset(dataset_def)
        total_rows = len(loaded_dataset)

        start = int(page_token) if page_token else 0
        if rows_in_page == -1:
            end = total_rows
        else:
            end = min(start + rows_in_page, total_rows)

        rows = [loaded_dataset[i] for i in range(start, end)]

        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )
-import io from typing import List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 -import base64 from abc import ABC, abstractmethod from dataclasses import dataclass -from urllib.parse import unquote from llama_stack.providers.datatypes import DatasetsProtocolPrivate -from llama_stack.providers.utils.memory.vector_store import parse_data_url +from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from .config import MetaReferenceDatasetIOConfig @@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset): if self.df is not None: return - # TODO: more robust support w/ data url - if self.dataset_def.url.uri.endswith(".csv"): - df = pandas.read_csv(self.dataset_def.url.uri) - elif self.dataset_def.url.uri.endswith(".xlsx"): - df = pandas.read_excel(self.dataset_def.url.uri) - elif self.dataset_def.url.uri.startswith("data:"): - parts = parse_data_url(self.dataset_def.url.uri) - data = parts["data"] - if parts["is_base64"]: - data = base64.b64decode(data) - else: - data = unquote(data) - encoding = parts["encoding"] or "utf-8" - data = data.encode(encoding) - - mime_type = parts["mimetype"] - mime_category = mime_type.split("/")[0] - data_bytes = io.BytesIO(data) - - if mime_category == "text": - df = pandas.read_csv(data_bytes) - else: - df = pandas.read_excel(data_bytes) - else: - raise ValueError(f"Unsupported file type: {self.dataset_def.url}") + df = get_dataframe_from_url(self.dataset_def.url) + if df is None: + raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") self.df = self._validate_dataset_schema(df) diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py index 4a61c9d93..48d8e2b04 100644 --- a/llama_stack/providers/inline/meta_reference/eval/eval.py +++ b/llama_stack/providers/inline/meta_reference/eval/eval.py @@ -9,6 +9,8 @@ from 
llama_models.llama3.api.datatypes import * # noqa: F403 from .....apis.common.job_types import Job from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus from llama_stack.apis.common.type_system import * # noqa: F403 +from tqdm import tqdm + from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTaskDef @@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): self.eval_tasks = {} - async def initialize(self) -> None: ... + async def initialize(self) -> None: + pass async def shutdown(self) -> None: ... @@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, - rows_in_page=-1, + rows_in_page=( + -1 if task_config.num_examples is None else task_config.num_examples + ), ) res = await self.evaluate_rows( task_id=task_id, @@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): ), "SamplingParams.max_tokens must be provided" generations = [] - for x in input_rows: + for x in tqdm(input_rows): if ColumnName.completion_input.value in x: input_content = eval(str(x[ColumnName.completion_input.value])) response = await self.inference_api.completion( diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py index c4add966d..6370ea5e5 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py @@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from 
llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( - EqualityScoringFn, -) - -from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( - LlmAsJudgeScoringFn, -) - -from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( - SubsetOfScoringFn, -) from .config import MetaReferenceScoringConfig +from .scoring_fn.equality_scoring_fn import EqualityScoringFn +from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn +from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn +from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn -FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn] +FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py index 99fa6cc3a..b54bf7ae8 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py @@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef equality = ScoringFnDef( identifier="meta-reference::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - parameters=[], return_type=NumberType(), ) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py index cfef52160..68d77b8df 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -26,7 +26,6 @@ Total 
# Answer-label prefixes ("Answer:", "答案：", ...) in the supported languages.
# Each entry is a regex fragment substituted into
# MULTILINGUAL_ANSWER_PATTERN_TEMPLATE; the scoring function tries the
# resulting patterns in list order and keeps the first match.
#
# NOTE(review): the previous list carried byte-identical duplicates for the
# CJK labels. They look like full-width-colon variants flattened to ASCII by
# Unicode normalization; confirm against the upstream list. The CJK labels
# below accept both ":" and "：" so either form of the file behaves the same,
# and the exact duplicates (which could never change a result) are dropped.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:​​​​​​",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*[:：]",
    r"答\s*[:：]",
    r"答复\s*[:：]",
    r"答曰\s*[:：]",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*[:：]",
    r"回答\s*[:：]",
    r"解答\s*[:：]",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]

# "{}" receives one label from MULTILINGUAL_ANSWER_REGEXES; group(1) captures
# the answer letter. "(?i)" makes the match case-insensitive, so a lowercase
# letter is captured as written. The last alternative covers full-width
# Latin letters (it used to be a redundant second copy of A-D).
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ-Ｄ])"
)
import re

from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from .common import aggregate_accuracy

from .fn_defs.regex_parser_multiple_choice_answer import (
    regex_parser_multiple_choice_answer,
)


class RegexParserScoringFn(BaseScoringFn):
    """
    A scoring_fn that parses the answer out of the generated response with the
    configured regexes and checks whether it matches expected_answer.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        """Score one row by regex-parsing its generated answer.

        Args:
            input_row: Must contain ``"expected_answer"`` and
                ``"generated_answer"``.
            scoring_fn_identifier: Key into ``supported_fn_defs_registry``.
            scoring_params: Optional per-call override of the registered
                definition's params.

        Returns:
            ``{"score": 1.0}`` when the first regex that matches captures a
            non-empty group equal to ``expected_answer``, else
            ``{"score": 0.0}``.

        Raises:
            AssertionError: if no identifier is given or the effective params
                are missing or not of the regex-parser type.
            KeyError: if the identifier is not registered or the row lacks a
                required column.
        """
        assert (
            scoring_fn_identifier is not None
        ), "Scoring function identifier not found."
        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

        # Resolve the override into a local instead of assigning it to
        # fn_def.params: the registry holds the module-level definition, so
        # writing to it would leak this call's params into every later call.
        params = scoring_params if scoring_params is not None else fn_def.params

        assert (
            params is not None
            and params.type == ScoringConfigType.regex_parser.value
        ), f"RegexParserScoringFnParams not found for {fn_def}."

        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]

        # parse answer according to regex; the first matching pattern wins
        parsed_answer = None
        for regex in params.parsing_regexes:
            match = re.search(regex, generated_answer)
            if match:
                parsed_answer = match.group(1)
                break

        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
        return {
            "score": score,
        }

    async def aggregate(
        self, scoring_results: List[ScoringResultRow]
    ) -> Dict[str, Any]:
        """Aggregate per-row results with ``aggregate_accuracy``."""
        return aggregate_accuracy(scoring_results)
config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig", + ), + ), ] diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py index 7d7615b55..d810d5e02 100644 --- a/llama_stack/providers/tests/datasetio/fixtures.py +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -31,7 +31,20 @@ def datasetio_meta_reference() -> ProviderFixture: ) -DATASETIO_FIXTURES = ["meta_reference", "remote"] +@pytest.fixture(scope="session") +def datasetio_huggingface() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + config={}, + ) + ], + ) + + +DATASETIO_FIXTURES = ["meta_reference", "remote", "huggingface"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py index 064feb611..985a8bc37 100644 --- a/llama_stack/providers/tests/eval/conftest.py +++ b/llama_stack/providers/tests/eval/conftest.py @@ -34,6 +34,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="meta_reference_eval_together_inference", marks=pytest.mark.meta_reference_eval_together_inference, ), + pytest.param( + { + "eval": "meta_reference", + "scoring": "meta_reference", + "datasetio": "huggingface", + "inference": "together", + }, + id="meta_reference_eval_together_inference_huggingface_datasetio", + marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio, + ), ] @@ -41,6 +51,7 @@ def pytest_configure(config): for fixture_name in [ "meta_reference_eval_fireworks_inference", "meta_reference_eval_together_inference", + "meta_reference_eval_together_inference_huggingface_datasetio", ]: config.addinivalue_line( "markers", diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index a55a754c5..fdd4dcfbb 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ 
b/llama_stack/providers/tests/eval/test_eval.py @@ -7,10 +7,15 @@ import pytest -from llama_models.llama3.api import SamplingParams +from llama_models.llama3.api import SamplingParams, URL + +from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType + +from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider from llama_stack.apis.eval.eval import ( AppEvalTaskConfig, + BenchmarkEvalTaskConfig, EvalTaskDefWithProvider, ModelCandidate, ) @@ -21,7 +26,7 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase # How to run this test: # # pytest llama_stack/providers/tests/eval/test_eval.py -# -m "meta_reference" +# -m "meta_reference_eval_together_inference_huggingface_datasetio" # -v -s --tb=short --disable-warnings @@ -33,21 +38,26 @@ class Testeval: eval_tasks_impl = eval_stack[Api.eval_tasks] response = await eval_tasks_impl.list_eval_tasks() assert isinstance(response, list) - assert len(response) == 0 @pytest.mark.asyncio async def test_eval_evaluate_rows(self, eval_stack): - eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = ( + eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = ( eval_stack[Api.eval], eval_stack[Api.eval_tasks], eval_stack[Api.datasetio], eval_stack[Api.datasets], + eval_stack[Api.models], ) + for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]: + await models_impl.register_model( + model_id=model_id, + provider_id="", + ) await register_dataset( datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" ) response = await datasets_impl.list_datasets() - assert len(response) == 1 + rows = await datasetio_impl.get_rows_paginated( dataset_id="test_dataset_for_eval", rows_in_page=3, @@ -66,7 +76,6 @@ class Testeval: provider_id="meta-reference", ) await eval_tasks_impl.register_eval_task(task_def) - response = await eval_impl.evaluate_rows( task_id=task_id, input_rows=rows.rows, @@ -84,11 +93,17 @@ class Testeval: 
@pytest.mark.asyncio async def test_eval_run_eval(self, eval_stack): - eval_impl, eval_tasks_impl, datasets_impl = ( + eval_impl, eval_tasks_impl, datasets_impl, models_impl = ( eval_stack[Api.eval], eval_stack[Api.eval_tasks], eval_stack[Api.datasets], + eval_stack[Api.models], ) + for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]: + await models_impl.register_model( + model_id=model_id, + provider_id="", + ) await register_dataset( datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" ) @@ -124,3 +139,72 @@ class Testeval: assert len(eval_response.generations) == 5 assert "meta-reference::subset_of" in eval_response.scores assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores + + @pytest.mark.asyncio + async def test_eval_run_benchmark_eval(self, eval_stack): + eval_impl, eval_tasks_impl, datasets_impl, models_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasets], + eval_stack[Api.models], + ) + for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]: + await models_impl.register_model( + model_id=model_id, + provider_id="", + ) + response = await datasets_impl.list_datasets() + assert len(response) > 0 + if response[0].provider_id != "huggingface": + pytest.skip( + "Only huggingface provider supports pre-registered remote datasets" + ) + # register dataset + mmlu = DatasetDefWithProvider( + identifier="mmlu", + url=URL(uri="https://huggingface.co/datasets/llamastack/evals"), + dataset_schema={ + "input_query": StringType(), + "expected_answer": StringType(), + "chat_completion_input": ChatCompletionInputType(), + }, + metadata={ + "path": "llamastack/evals", + "name": "evals__mmlu__details", + "split": "train", + }, + provider_id="", + ) + + await datasets_impl.register_dataset(mmlu) + + # register eval task + meta_reference_mmlu = EvalTaskDefWithProvider( + identifier="meta-reference-mmlu", + dataset_id="mmlu", + 
scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"], + provider_id="", + ) + + await eval_tasks_impl.register_eval_task(meta_reference_mmlu) + + # list benchmarks + response = await eval_tasks_impl.list_eval_tasks() + assert len(response) > 0 + + benchmark_id = "meta-reference-mmlu" + response = await eval_impl.run_eval( + task_id=benchmark_id, + task_config=BenchmarkEvalTaskConfig( + eval_candidate=ModelCandidate( + model="Llama3.2-3B-Instruct", + sampling_params=SamplingParams(), + ), + num_examples=3, + ), + ) + job_status = await eval_impl.job_status(benchmark_id, response.job_id) + assert job_status and job_status.value == "completed" + eval_response = await eval_impl.job_result(benchmark_id, response.job_id) + assert eval_response is not None + assert len(eval_response.generations) == 3 diff --git a/llama_stack/providers/utils/datasetio/__init__.py b/llama_stack/providers/utils/datasetio/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/datasetio/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py new file mode 100644 index 000000000..3faea9f95 --- /dev/null +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import base64
import io
from urllib.parse import unquote

import pandas

from llama_models.llama3.api.datatypes import URL

from llama_stack.providers.utils.memory.vector_store import parse_data_url


def get_dataframe_from_url(url: URL) -> pandas.DataFrame:
    """Load the tabular data referenced by ``url`` into a pandas DataFrame.

    Three kinds of ``url.uri`` are understood:

    * an inline ``data:`` URL -- the payload is decoded in memory and parsed
      as CSV when the MIME type's top-level category is ``text``, otherwise
      as an Excel workbook;
    * a location ending in ``.csv`` -- handed to ``pandas.read_csv``;
    * a location ending in ``.xlsx`` -- handed to ``pandas.read_excel``.

    Extension matching is case-insensitive.

    Args:
        url: Object whose ``uri`` attribute is the string to load from.

    Returns:
        The DataFrame produced by the matching pandas reader.

    Raises:
        ValueError: If ``url.uri`` is neither a ``data:`` URL nor ends in a
            supported extension.
    """
    uri = url.uri

    # Inline payloads must be recognised before any extension sniffing:
    # a data URL whose content merely *ends* in ".csv"/".xlsx" would otherwise
    # be passed to pandas as if it were a path and fail to open.
    if uri.startswith("data:"):
        # NOTE(review): parse_data_url is defined elsewhere; this code relies
        # on it returning the keys "data", "is_base64", "encoding" and
        # "mimetype" -- presumably it rejects malformed data URLs itself.
        parts = parse_data_url(uri)
        payload = parts["data"]
        if parts["is_base64"]:
            payload = base64.b64decode(payload)
        else:
            # Non-base64 payloads are percent-encoded text; turn them into
            # bytes using the declared charset, defaulting to UTF-8.
            payload = unquote(payload).encode(parts["encoding"] or "utf-8")

        buffer = io.BytesIO(payload)
        mime_category = parts["mimetype"].split("/")[0]
        if mime_category == "text":
            return pandas.read_csv(buffer)
        # Every non-text MIME category is treated as a spreadsheet.
        return pandas.read_excel(buffer)

    # Compare against a lower-cased copy so "DATA.CSV" is accepted too, but
    # always pass the untouched URI to pandas.
    lowered = uri.lower()
    if lowered.endswith(".csv"):
        return pandas.read_csv(uri)
    if lowered.endswith(".xlsx"):
        return pandas.read_excel(uri)

    raise ValueError(f"Unsupported file type: {url}")