From 0f50cfa561f427dd4c8c11b876f8d80f08744847 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Mon, 24 Mar 2025 20:54:04 -0400 Subject: [PATCH] feat(api): define a more coherent jobs api across different flows Signed-off-by: Ihar Hrachyshka --- docs/_static/llama-stack-spec.html | 1607 +++++++++-------- docs/_static/llama-stack-spec.yaml | 1103 ++++++----- .../apis/batch_inference/batch_inference.py | 18 +- llama_stack/apis/common/job_types.py | 60 +- llama_stack/apis/eval/eval.py | 91 +- .../apis/post_training/post_training.py | 64 +- .../synthetic_data_generation.py | 35 +- llama_stack/distribution/routers/routers.py | 63 +- .../inline/eval/meta_reference/eval.py | 73 +- .../post_training/torchtune/post_training.py | 88 +- .../recipes/lora_finetuning_single_device.py | 25 +- .../post_training/nvidia/post_training.py | 175 +- llama_stack/strong_typing/inspection.py | 12 +- .../post_training/test_post_training.py | 6 +- .../nvidia/test_supervised_fine_tuning.py | 114 +- 15 files changed, 1864 insertions(+), 1670 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 54d888441..3a494e06c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -171,42 +171,6 @@ } } }, - "/v1/post-training/job/cancel": { - "post": { - "responses": { - "200": { - "description": "OK" - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/CancelTrainingJobRequest" - } - } - }, - "required": true - } - } - }, "/v1/inference/chat-completion": { "post": { "responses": { @@ -764,6 +728,41 @@ ] } }, + "/v1/evaluate/job/{job_id}": { + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/files/{bucket}/{key}": { "get": { "responses": { @@ -859,6 +858,76 @@ ] } }, + "/v1/post-training/job/{job_id}": { + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "PostTraining (Coming Soon)" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, + "/v1/synthetic-data-generation/job/{job_id}": { + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "SyntheticDataGeneration (Coming Soon)" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/inference/embeddings": { "post": { "responses": { @@ -902,15 +971,15 @@ } } }, - "/v1/eval/benchmarks/{benchmark_id}/evaluations": { + "/v1/eval/benchmarks/{benchmark_id}/evaluate": { "post": { "responses": { "200": { - "description": "EvaluateResponse object containing generations and scores", + "description": "OK", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateResponse" + "$ref": "#/components/schemas/EvaluateJob" } } } @@ -931,12 +1000,11 @@ "tags": [ "Eval" ], - "description": "Evaluate a list of rows on a benchmark.", + "description": "", "parameters": [ { "name": "benchmark_id", "in": "path", - "description": "The ID of the benchmark to run the evaluation on.", "required": true, "schema": { "type": "string" @@ -947,7 +1015,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateRowsRequest" + "$ref": "#/components/schemas/EvaluateRequest" } } }, @@ -1203,6 +1271,89 @@ ] } }, + "/v1/evaluate/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvaluateJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + }, + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvaluateJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateEvaluateJobRequest" + } + } + }, + "required": true + } + } + }, "/v1/models/{model_id}": { "get": { "responses": { @@ -1278,6 +1429,89 @@ ] } }, + "/v1/post-training/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PostTrainingJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "PostTraining (Coming Soon)" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + }, + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PostTrainingJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "PostTraining (Coming Soon)" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdatePostTrainingJobRequest" + } + } + }, + "required": true + } + } + }, "/v1/scoring-functions/{scoring_fn_id}": { "get": { "responses": { @@ -1464,6 +1698,80 @@ } } }, + "/v1/synthetic-data-generation/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SyntheticDataGenerationJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "SyntheticDataGeneration (Coming Soon)" + ], + "description": "", + "parameters": [] + }, + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SyntheticDataGenerationJob" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "SyntheticDataGeneration (Coming Soon)" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateSyntheticDataGenerationJobRequest" + } + } + }, + "required": true + } + } + }, "/v1/tools/{tool_name}": { "get": { "responses": { @@ -1623,123 +1931,6 @@ ] } }, - "/v1/post-training/job/artifacts": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [ - { - "name": "job_uuid", - "in": "query", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, - "/v1/post-training/job/status": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/PostTrainingJobStatusResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [ - { - "name": "job_uuid", - "in": "query", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, - "/v1/post-training/jobs": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListPostTrainingJobsResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [] - } - }, "/v1/files/session:{upload_id}": { "get": { "responses": { @@ -2168,153 +2359,6 @@ ] } }, - "/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}": { - "get": { - "responses": { - "200": { - "description": "The status of the evaluationjob.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Job" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Eval" - ], - "description": "Get the status of a job.", - "parameters": [ - { - "name": "benchmark_id", - "in": "path", - "description": "The ID of the benchmark to run the evaluation on.", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "job_id", - "in": "path", - "description": "The ID of the job to get the status of.", - "required": true, - "schema": { - "type": "string" - } - } - ] - }, - "delete": { - "responses": { - "200": { - "description": "OK" - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Eval" - ], - "description": "Cancel a job.", - "parameters": [ - { - "name": "benchmark_id", - "in": "path", - "description": "The ID of the benchmark to run the evaluation on.", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "job_id", - "in": "path", - "description": "The ID of the job to cancel.", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, - "/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result": { - "get": { - "responses": { - "200": { - "description": "The result of the job.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/EvaluateResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Eval" - ], - "description": "Get the result of a job.", - "parameters": [ - { - "name": "benchmark_id", - "in": "path", - "description": "The ID of the benchmark to run the evaluation on.", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "job_id", - "in": "path", - "description": "The ID of the job to get the result of.", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/agents/{agent_id}/sessions": { "get": { "responses": { @@ -2499,6 +2543,39 @@ } } }, + "/v1/evaluate/jobs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListEvaluateJobsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [] + } + }, "/v1/files/{bucket}": { "get": { "responses": { @@ -2616,6 +2693,39 @@ } } }, + "/v1/post-training/jobs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListPostTrainingJobsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "PostTraining (Coming Soon)" + ], + "description": "", + "parameters": [] + } + }, "/v1/providers": { "get": { "responses": { @@ -2873,6 +2983,39 @@ } } }, + "/v1/synthetic-data-generation/jobs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListSyntheticDataGenerationJobsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "SyntheticDataGeneration (Coming Soon)" + ], + "description": "", + "parameters": [] + } + }, "/v1/toolgroups": { "get": { "responses": { @@ -3509,59 +3652,6 @@ } } }, - "/v1/eval/benchmarks/{benchmark_id}/jobs": { - "post": { - "responses": { - "200": { - "description": "The job that was created to run the evaluation.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Job" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Eval" - ], - "description": "Run an evaluation on a benchmark.", - "parameters": [ - { - "name": "benchmark_id", - "in": "path", - "description": "The ID of the benchmark to run the evaluation on.", - "required": true, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/RunEvalRequest" - } - } - }, - "required": true - } - } - }, "/v1/safety/run-shield": { "post": { "responses": { @@ -3778,7 +3868,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SyntheticDataGenerationResponse" + "$ref": "#/components/schemas/SyntheticDataGenerationJob" } } } @@ -4834,19 +4924,6 @@ "title": "CompletionResponse", "description": "Response from a completion request." }, - "CancelTrainingJobRequest": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "CancelTrainingJobRequest" - }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -6525,139 +6602,29 @@ } } }, - "EvaluateRowsRequest": { + "EvaluateRequest": { "type": "object", "properties": { - "input_rows": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "The rows to evaluate." - }, - "scoring_functions": { - "type": "array", - "items": { - "type": "string" - }, - "description": "The scoring functions to use for the evaluation." - }, "benchmark_config": { - "$ref": "#/components/schemas/BenchmarkConfig", - "description": "The configuration for the benchmark." + "$ref": "#/components/schemas/BenchmarkConfig" } }, "additionalProperties": false, "required": [ - "input_rows", - "scoring_functions", "benchmark_config" ], - "title": "EvaluateRowsRequest" + "title": "EvaluateRequest" }, - "EvaluateResponse": { + "JobArtifact": { "type": "object", "properties": { - "generations": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "The generations from the evaluation." + "name": { + "type": "string" }, - "scores": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ScoringResult" - }, - "description": "The scores from the evaluation." - } - }, - "additionalProperties": false, - "required": [ - "generations", - "scores" - ], - "title": "EvaluateResponse", - "description": "The response from an evaluation." - }, - "ScoringResult": { - "type": "object", - "properties": { - "score_rows": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "The scoring result for each row. Each row is a map of column name to value." + "type": { + "type": "string" }, - "aggregated_results": { + "metadata": { "type": "object", "additionalProperties": { "oneOf": [ @@ -6680,17 +6647,100 @@ "type": "object" } ] - }, - "description": "Map of metric name to aggregated value" + } + }, + "uri": { + "type": "string" } }, "additionalProperties": false, "required": [ - "score_rows", - "aggregated_results" + "name", + "type" ], - "title": "ScoringResult", - "description": "A scoring result for a single row." + "title": "JobArtifact" + }, + "JobStatusDetails": { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + }, + "message": { + "type": "string" + }, + "timestamp": { + "type": "string", + "format": "date-time" + } + }, + "additionalProperties": false, + "required": [ + "status", + "timestamp" + ], + "title": "JobStatusDetails" + }, + "EvaluateJob": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" + } + }, + "events": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobStatusDetails" + } + }, + "type": { + "type": "string", + "const": "eval", + "default": "eval" + }, + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "id", + "artifacts", + "events", + "type", + "status" + ], + "title": "EvaluateJob" }, "Agent": { "type": "object", @@ -7101,6 +7151,55 @@ ], "title": "ModelType" }, + "PostTrainingJob": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" + } + }, + "events": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobStatusDetails" + } + }, + "type": { + "type": "string", + "const": "post-training", + "default": "post-training" + }, + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "id", + "artifacts", + "events", + "type", + "status" + ], + "title": "PostTrainingJob" + }, "AgentTurnInputType": { "type": "object", "properties": { @@ -7575,6 +7674,55 @@ ], "title": "QuerySpanTreeResponse" }, + "SyntheticDataGenerationJob": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" + } + }, + "events": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobStatusDetails" + } + }, + "type": { + "type": "string", + "const": "synthetic-data-generation", + "default": "synthetic-data-generation" + }, + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "id", + "artifacts", + "events", + "type", + "status" + ], + "title": "SyntheticDataGenerationJob" + }, "Tool": { "type": "object", "properties": { @@ -7736,127 +7884,6 @@ ], "title": "Trace" }, - "Checkpoint": { - "description": "Checkpoint created during training runs", - "title": "Checkpoint" - }, - "PostTrainingJobArtifactsResponse": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - }, - "checkpoints": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Checkpoint" - } - } - }, - "additionalProperties": false, - "required": [ - "job_uuid", - "checkpoints" - ], - "title": "PostTrainingJobArtifactsResponse", - "description": "Artifacts of a finetuning job." - }, - "PostTrainingJobStatusResponse": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - }, - "status": { - "type": "string", - "enum": [ - "completed", - "in_progress", - "failed", - "scheduled", - "cancelled" - ], - "title": "JobStatus" - }, - "scheduled_at": { - "type": "string", - "format": "date-time" - }, - "started_at": { - "type": "string", - "format": "date-time" - }, - "completed_at": { - "type": "string", - "format": "date-time" - }, - "resources_allocated": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "checkpoints": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Checkpoint" - } - } - }, - "additionalProperties": false, - "required": [ - "job_uuid", - "status", - "checkpoints" - ], - "title": "PostTrainingJobStatusResponse", - "description": "Status of a finetuning job." - }, - "ListPostTrainingJobsResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "PostTrainingJob" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ], - "title": "ListPostTrainingJobsResponse" - }, "VectorDB": { "type": "object", "properties": { @@ -8259,31 +8286,6 @@ "title": "PaginatedResponse", "description": "A generic paginated response that follows a simple format." }, - "Job": { - "type": "object", - "properties": { - "job_id": { - "type": "string" - }, - "status": { - "type": "string", - "enum": [ - "completed", - "in_progress", - "failed", - "scheduled", - "cancelled" - ], - "title": "JobStatus" - } - }, - "additionalProperties": false, - "required": [ - "job_id", - "status" - ], - "title": "Job" - }, "ListAgentSessionsResponse": { "type": "object", "properties": { @@ -8379,6 +8381,68 @@ ], "title": "ListDatasetsResponse" }, + "ListEvaluateJobsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" + } + }, + "events": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobStatusDetails" + } + }, + "type": { + "type": "string", + "const": "eval", + "default": "eval" + }, + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "id", + "artifacts", + "events", + "type", + "status" + ], + "title": "EvaluateJob" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "ListEvaluateJobsResponse" + }, "ListFileResponse": { "type": "object", "properties": { @@ -8413,6 +8477,22 @@ ], "title": "ListModelsResponse" }, + "ListPostTrainingJobsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PostTrainingJob" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "ListPostTrainingJobsResponse" + }, "ListProvidersResponse": { "type": "object", "properties": { @@ -8517,6 +8597,22 @@ ], "title": "ListShieldsResponse" }, + "ListSyntheticDataGenerationJobsResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/SyntheticDataGenerationJob" + } + } + }, + "additionalProperties": false, + "required": [ + "items" + ], + "title": "ListSyntheticDataGenerationJobsResponse" + }, "ListToolGroupsResponse": { "type": "object", "properties": { @@ -10279,19 +10375,6 @@ ], "title": "PreferenceOptimizeRequest" }, - "PostTrainingJob": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "PostTrainingJob" - }, "DefaultRAGQueryGeneratorConfig": { "type": "object", "properties": { @@ -10992,20 +11075,6 @@ ], "title": "ResumeAgentTurnRequest" }, - "RunEvalRequest": { - "type": "object", - "properties": { - "benchmark_config": { - "$ref": "#/components/schemas/BenchmarkConfig", - "description": "The configuration for the benchmark." - } - }, - "additionalProperties": false, - "required": [ - "benchmark_config" - ], - "title": "RunEvalRequest" - }, "RunShieldRequest": { "type": "object", "properties": { @@ -11164,6 +11233,73 @@ "title": "ScoreResponse", "description": "The response from scoring." }, + "ScoringResult": { + "type": "object", + "properties": { + "score_rows": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "description": "The scoring result for each row. Each row is a map of column name to value." + }, + "aggregated_results": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Map of metric name to aggregated value" + } + }, + "additionalProperties": false, + "required": [ + "score_rows", + "aggregated_results" + ], + "title": "ScoringResult", + "description": "A scoring result for a single row." + }, "ScoreBatchRequest": { "type": "object", "properties": { @@ -11411,69 +11547,90 @@ ], "title": "SyntheticDataGenerateRequest" }, - "SyntheticDataGenerationResponse": { + "UpdateEvaluateJobRequest": { "type": "object", "properties": { - "synthetic_data": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "statistics": { + "job": { "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" + "properties": { + "id": { + "type": "string" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" } - ] - } + }, + "events": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobStatusDetails" + } + }, + "type": { + "type": "string", + "const": "eval", + "default": "eval" + }, + "status": { + "type": "string", + "enum": [ + "unknown", + "new", + "scheduled", + "running", + "paused", + "resuming", + "cancelled", + "failed", + "completed" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "id", + "artifacts", + "events", + "type", + "status" + ], + "title": "EvaluateJob" } }, "additionalProperties": false, "required": [ - "synthetic_data" + "job" ], - "title": "SyntheticDataGenerationResponse", - "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." + "title": "UpdateEvaluateJobRequest" + }, + "UpdatePostTrainingJobRequest": { + "type": "object", + "properties": { + "job": { + "$ref": "#/components/schemas/PostTrainingJob" + } + }, + "additionalProperties": false, + "required": [ + "job" + ], + "title": "UpdatePostTrainingJobRequest" + }, + "UpdateSyntheticDataGenerationJobRequest": { + "type": "object", + "properties": { + "job": { + "$ref": "#/components/schemas/SyntheticDataGenerationJob" + } + }, + "additionalProperties": false, + "required": [ + "job" + ], + "title": "UpdateSyntheticDataGenerationJobRequest" }, "VersionInfo": { "type": "object", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index cf657bff9..b6dc8a61d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -98,31 +98,6 @@ paths: schema: $ref: '#/components/schemas/BatchCompletionRequest' required: true - /v1/post-training/job/cancel: - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/CancelTrainingJobRequest' - required: true /v1/inference/chat-completion: post: responses: @@ -516,6 +491,30 @@ paths: required: true schema: type: string + /v1/evaluate/job/{job_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Eval + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string /v1/files/{bucket}/{key}: get: responses: @@ -585,6 +584,54 @@ paths: required: true schema: type: string + /v1/post-training/job/{job_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - PostTraining (Coming Soon) + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string + /v1/synthetic-data-generation/job/{job_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - SyntheticDataGeneration (Coming Soon) + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string /v1/inference/embeddings: post: responses: @@ -618,16 +665,15 @@ paths: schema: $ref: '#/components/schemas/EmbeddingsRequest' required: true - /v1/eval/benchmarks/{benchmark_id}/evaluations: + /v1/eval/benchmarks/{benchmark_id}/evaluate: post: responses: '200': - description: >- - EvaluateResponse object containing generations and scores + description: OK content: application/json: schema: - $ref: '#/components/schemas/EvaluateResponse' + $ref: '#/components/schemas/EvaluateJob' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -640,12 +686,10 @@ paths: $ref: '#/components/responses/DefaultError' tags: - Eval - description: Evaluate a list of rows on a benchmark. + description: '' parameters: - name: benchmark_id in: path - description: >- - The ID of the benchmark to run the evaluation on. required: true schema: type: string @@ -653,7 +697,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/EvaluateRowsRequest' + $ref: '#/components/schemas/EvaluateRequest' required: true /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}: get: @@ -824,6 +868,62 @@ paths: required: true schema: type: string + /v1/evaluate/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Eval + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Eval + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateEvaluateJobRequest' + required: true /v1/models/{model_id}: get: responses: @@ -875,6 +975,62 @@ paths: required: true schema: type: string + /v1/post-training/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - PostTraining (Coming Soon) + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - PostTraining (Coming Soon) + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdatePostTrainingJobRequest' + required: true /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -998,6 +1154,57 @@ paths: schema: $ref: '#/components/schemas/GetSpanTreeRequest' required: true + /v1/synthetic-data-generation/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - SyntheticDataGeneration (Coming Soon) + description: '' + parameters: [] + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationJob' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - SyntheticDataGeneration (Coming Soon) + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateSyntheticDataGenerationJobRequest' + required: true /v1/tools/{tool_name}: get: responses: @@ -1105,85 +1312,6 @@ paths: required: true schema: type: string - /v1/post-training/job/artifacts: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: - - name: job_uuid - in: query - required: true - schema: - type: string - /v1/post-training/job/status: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/PostTrainingJobStatusResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: - - name: job_uuid - in: query - required: true - schema: - type: string - /v1/post-training/jobs: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/ListPostTrainingJobsResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: [] /v1/files/session:{upload_id}: get: responses: @@ -1492,109 +1620,6 @@ paths: required: false schema: type: integer - /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}: - get: - responses: - '200': - description: The status of the evaluationjob. - content: - application/json: - schema: - $ref: '#/components/schemas/Job' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Eval - description: Get the status of a job. - parameters: - - name: benchmark_id - in: path - description: >- - The ID of the benchmark to run the evaluation on. - required: true - schema: - type: string - - name: job_id - in: path - description: The ID of the job to get the status of. - required: true - schema: - type: string - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Eval - description: Cancel a job. - parameters: - - name: benchmark_id - in: path - description: >- - The ID of the benchmark to run the evaluation on. - required: true - schema: - type: string - - name: job_id - in: path - description: The ID of the job to cancel. - required: true - schema: - type: string - /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result: - get: - responses: - '200': - description: The result of the job. - content: - application/json: - schema: - $ref: '#/components/schemas/EvaluateResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Eval - description: Get the result of a job. - parameters: - - name: benchmark_id - in: path - description: >- - The ID of the benchmark to run the evaluation on. - required: true - schema: - type: string - - name: job_id - in: path - description: The ID of the job to get the result of. - required: true - schema: - type: string /v1/agents/{agent_id}/sessions: get: responses: @@ -1723,6 +1748,29 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true + /v1/evaluate/jobs: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListEvaluateJobsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Eval + description: '' + parameters: [] /v1/files/{bucket}: get: responses: @@ -1803,6 +1851,29 @@ paths: schema: $ref: '#/components/schemas/RegisterModelRequest' required: true + /v1/post-training/jobs: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListPostTrainingJobsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - PostTraining (Coming Soon) + description: '' + parameters: [] /v1/providers: get: responses: @@ -1980,6 +2051,29 @@ paths: schema: $ref: '#/components/schemas/RegisterShieldRequest' required: true + /v1/synthetic-data-generation/jobs: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListSyntheticDataGenerationJobsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - SyntheticDataGeneration (Coming Soon) + description: '' + parameters: [] /v1/toolgroups: get: responses: @@ -2423,43 +2517,6 @@ paths: schema: $ref: '#/components/schemas/ResumeAgentTurnRequest' required: true - /v1/eval/benchmarks/{benchmark_id}/jobs: - post: - responses: - '200': - description: >- - The job that was created to run the evaluation. - content: - application/json: - schema: - $ref: '#/components/schemas/Job' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Eval - description: Run an evaluation on a benchmark. - parameters: - - name: benchmark_id - in: path - description: >- - The ID of the benchmark to run the evaluation on. - required: true - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RunEvalRequest' - required: true /v1/safety/run-shield: post: responses: @@ -2610,7 +2667,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SyntheticDataGenerationResponse' + $ref: '#/components/schemas/SyntheticDataGenerationJob' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -3360,15 +3417,6 @@ components: - stop_reason title: CompletionResponse description: Response from a completion request. - CancelTrainingJobRequest: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: CancelTrainingJobRequest ChatCompletionRequest: type: object properties: @@ -4586,82 +4634,23 @@ components: llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' regex_parser: '#/components/schemas/RegexParserScoringFnParams' basic: '#/components/schemas/BasicScoringFnParams' - EvaluateRowsRequest: + EvaluateRequest: type: object properties: - input_rows: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The rows to evaluate. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the evaluation. benchmark_config: $ref: '#/components/schemas/BenchmarkConfig' - description: The configuration for the benchmark. additionalProperties: false required: - - input_rows - - scoring_functions - benchmark_config - title: EvaluateRowsRequest - EvaluateResponse: + title: EvaluateRequest + JobArtifact: type: object properties: - generations: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The generations from the evaluation. - scores: - type: object - additionalProperties: - $ref: '#/components/schemas/ScoringResult' - description: The scores from the evaluation. - additionalProperties: false - required: - - generations - - scores - title: EvaluateResponse - description: The response from an evaluation. - ScoringResult: - type: object - properties: - score_rows: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The scoring result for each row. Each row is a map of column name to value. - aggregated_results: + name: + type: string + type: + type: string + metadata: type: object additionalProperties: oneOf: @@ -4671,13 +4660,77 @@ components: - type: string - type: array - type: object - description: Map of metric name to aggregated value + uri: + type: string additionalProperties: false required: - - score_rows - - aggregated_results - title: ScoringResult - description: A scoring result for a single row. + - name + - type + title: JobArtifact + JobStatusDetails: + type: object + properties: + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + message: + type: string + timestamp: + type: string + format: date-time + additionalProperties: false + required: + - status + - timestamp + title: JobStatusDetails + EvaluateJob: + type: object + properties: + id: + type: string + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + events: + type: array + items: + $ref: '#/components/schemas/JobStatusDetails' + type: + type: string + const: eval + default: eval + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + additionalProperties: false + required: + - id + - artifacts + - events + - type + - status + title: EvaluateJob Agent: type: object properties: @@ -4951,6 +5004,44 @@ components: - llm - embedding title: ModelType + PostTrainingJob: + type: object + properties: + id: + type: string + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + events: + type: array + items: + $ref: '#/components/schemas/JobStatusDetails' + type: + type: string + const: post-training + default: post-training + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + additionalProperties: false + required: + - id + - artifacts + - events + - type + - status + title: PostTrainingJob AgentTurnInputType: type: object properties: @@ -5254,6 +5345,44 @@ components: required: - data title: QuerySpanTreeResponse + SyntheticDataGenerationJob: + type: object + properties: + id: + type: string + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + events: + type: array + items: + $ref: '#/components/schemas/JobStatusDetails' + type: + type: string + const: synthetic-data-generation + default: synthetic-data-generation + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + additionalProperties: false + required: + - id + - artifacts + - events + - type + - status + title: SyntheticDataGenerationJob Tool: type: object properties: @@ -5356,86 +5485,6 @@ components: - root_span_id - start_time title: Trace - Checkpoint: - description: Checkpoint created during training runs - title: Checkpoint - PostTrainingJobArtifactsResponse: - type: object - properties: - job_uuid: - type: string - checkpoints: - type: array - items: - $ref: '#/components/schemas/Checkpoint' - additionalProperties: false - required: - - job_uuid - - checkpoints - title: PostTrainingJobArtifactsResponse - description: Artifacts of a finetuning job. - PostTrainingJobStatusResponse: - type: object - properties: - job_uuid: - type: string - status: - type: string - enum: - - completed - - in_progress - - failed - - scheduled - - cancelled - title: JobStatus - scheduled_at: - type: string - format: date-time - started_at: - type: string - format: date-time - completed_at: - type: string - format: date-time - resources_allocated: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - checkpoints: - type: array - items: - $ref: '#/components/schemas/Checkpoint' - additionalProperties: false - required: - - job_uuid - - status - - checkpoints - title: PostTrainingJobStatusResponse - description: Status of a finetuning job. - ListPostTrainingJobsResponse: - type: object - properties: - data: - type: array - items: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: PostTrainingJob - additionalProperties: false - required: - - data - title: ListPostTrainingJobsResponse VectorDB: type: object properties: @@ -5669,25 +5718,6 @@ components: title: PaginatedResponse description: >- A generic paginated response that follows a simple format. - Job: - type: object - properties: - job_id: - type: string - status: - type: string - enum: - - completed - - in_progress - - failed - - scheduled - - cancelled - title: JobStatus - additionalProperties: false - required: - - job_id - - status - title: Job ListAgentSessionsResponse: type: object properties: @@ -5755,6 +5785,53 @@ components: required: - data title: ListDatasetsResponse + ListEvaluateJobsResponse: + type: object + properties: + data: + type: array + items: + type: object + properties: + id: + type: string + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + events: + type: array + items: + $ref: '#/components/schemas/JobStatusDetails' + type: + type: string + const: eval + default: eval + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + additionalProperties: false + required: + - id + - artifacts + - events + - type + - status + title: EvaluateJob + additionalProperties: false + required: + - data + title: ListEvaluateJobsResponse ListFileResponse: type: object properties: @@ -5780,6 +5857,17 @@ components: required: - data title: ListModelsResponse + ListPostTrainingJobsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/PostTrainingJob' + additionalProperties: false + required: + - data + title: ListPostTrainingJobsResponse ListProvidersResponse: type: object properties: @@ -5852,6 +5940,17 @@ components: required: - data title: ListShieldsResponse + ListSyntheticDataGenerationJobsResponse: + type: object + properties: + items: + type: array + items: + $ref: '#/components/schemas/SyntheticDataGenerationJob' + additionalProperties: false + required: + - items + title: ListSyntheticDataGenerationJobsResponse ListToolGroupsResponse: type: object properties: @@ -7059,15 +7158,6 @@ components: - hyperparam_search_config - logger_config title: PreferenceOptimizeRequest - PostTrainingJob: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: PostTrainingJob DefaultRAGQueryGeneratorConfig: type: object properties: @@ -7514,16 +7604,6 @@ components: required: - tool_responses title: ResumeAgentTurnRequest - RunEvalRequest: - type: object - properties: - benchmark_config: - $ref: '#/components/schemas/BenchmarkConfig' - description: The configuration for the benchmark. - additionalProperties: false - required: - - benchmark_config - title: RunEvalRequest RunShieldRequest: type: object properties: @@ -7620,6 +7700,40 @@ components: - results title: ScoreResponse description: The response from scoring. + ScoringResult: + type: object + properties: + score_rows: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The scoring result for each row. Each row is a map of column name to value. + aggregated_results: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Map of metric name to aggregated value + additionalProperties: false + required: + - score_rows + - aggregated_results + title: ScoringResult + description: A scoring result for a single row. ScoreBatchRequest: type: object properties: @@ -7777,38 +7891,69 @@ components: - dialogs - filtering_function title: SyntheticDataGenerateRequest - SyntheticDataGenerationResponse: + UpdateEvaluateJobRequest: type: object properties: - synthetic_data: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - statistics: + job: type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + properties: + id: + type: string + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + events: + type: array + items: + $ref: '#/components/schemas/JobStatusDetails' + type: + type: string + const: eval + default: eval + status: + type: string + enum: + - unknown + - new + - scheduled + - running + - paused + - resuming + - cancelled + - failed + - completed + title: JobStatus + additionalProperties: false + required: + - id + - artifacts + - events + - type + - status + title: EvaluateJob additionalProperties: false required: - - synthetic_data - title: SyntheticDataGenerationResponse - description: >- - Response from the synthetic data generation. Batch of (prompt, response, score) - tuples that pass the threshold. + - job + title: UpdateEvaluateJobRequest + UpdatePostTrainingJobRequest: + type: object + properties: + job: + $ref: '#/components/schemas/PostTrainingJob' + additionalProperties: false + required: + - job + title: UpdatePostTrainingJobRequest + UpdateSyntheticDataGenerationJobRequest: + type: object + properties: + job: + $ref: '#/components/schemas/SyntheticDataGenerationJob' + additionalProperties: false + required: + - job + title: UpdateSyntheticDataGenerationJobRequest VersionInfo: type: object properties: diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 7a324128d..ad536cb31 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -4,9 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Protocol, runtime_checkable +from typing import List, Literal, Optional, Protocol, runtime_checkable -from llama_stack.apis.common.job_types import Job +from pydantic import BaseModel + +from llama_stack.apis.common.job_types import BaseJob from llama_stack.apis.inference import ( InterleavedContent, LogProbConfig, @@ -20,6 +22,14 @@ from llama_stack.apis.inference import ( from llama_stack.schema_utils import webmethod +class BatchInferenceJob(BaseJob, BaseModel): + type: Literal["batch_inference"] = "batch_inference" + + +class ListBatchInferenceJobsResponse(BaseModel): + data: list[BatchInferenceJob] + + @runtime_checkable class BatchInference(Protocol): """Batch inference API for generating completions and chat completions. @@ -38,7 +48,7 @@ class BatchInference(Protocol): sampling_params: Optional[SamplingParams] = None, response_format: Optional[ResponseFormat] = None, logprobs: Optional[LogProbConfig] = None, - ) -> Job: ... + ) -> BatchInferenceJob: ... @webmethod(route="/batch-inference/chat-completion", method="POST") async def chat_completion( @@ -52,4 +62,4 @@ class BatchInference(Protocol): tool_prompt_format: Optional[ToolPromptFormat] = None, response_format: Optional[ResponseFormat] = None, logprobs: Optional[LogProbConfig] = None, - ) -> Job: ... + ) -> BatchInferenceJob: ... diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index ca6bcaf63..730a9bd2e 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -3,22 +3,68 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from datetime import datetime, timezone +from enum import Enum, unique +from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field, computed_field from llama_stack.schema_utils import json_schema_type +@unique class JobStatus(Enum): - completed = "completed" - in_progress = "in_progress" - failed = "failed" + unknown = "unknown" + new = "new" scheduled = "scheduled" + running = "running" + paused = "paused" + resuming = "resuming" cancelled = "cancelled" + failed = "failed" + completed = "completed" @json_schema_type -class Job(BaseModel): - job_id: str +class JobStatusDetails(BaseModel): status: JobStatus + message: str | None = None + timestamp: datetime + + +@json_schema_type +class JobArtifact(BaseModel): + name: str + + # TODO: should it be a Literal / Enum? + type: str + + # Any additional metadata the artifact may have + # TODO: is Any the right type here? What happens when the caller passes a value without a __repr__? + metadata: dict[str, Any] | None = None + + # TODO: enforce type to be a URI + uri: str | None = None # points to /files + + +def _get_job_status_details(status: JobStatus) -> JobStatusDetails: + return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc)) + + +class BaseJob(BaseModel): + id: str # TODO: make it a UUID? + + artifacts: list[JobArtifact] = Field(default_factory=list) + events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)]) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if "type" not in cls.__annotations__: + raise ValueError(f"Class {cls.__name__} must have a type field") + + @computed_field + def status(self) -> JobStatus: + return self.events[-1].status + + def update_status(self, value: JobStatus): + self.events.append(_get_job_status_details(value)) diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 0e5959c37..fe9cbecce 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -4,15 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Literal, Optional, Protocol, Union +from typing import Dict, Literal, Optional, Protocol, Union from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.agents import AgentConfig -from llama_stack.apis.common.job_types import Job +from llama_stack.apis.common.job_types import BaseJob from llama_stack.apis.inference import SamplingParams, SystemMessage -from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin register_schema(EvalCandidate, name="EvalCandidate") +class EvaluateJob(BaseJob, BaseModel): + type: Literal["eval"] = "eval" + + +class ListEvaluateJobsResponse(BaseModel): + data: list[EvaluateJob] + + @json_schema_type class BenchmarkConfig(BaseModel): """A benchmark configuration for evaluation. @@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel): # we could optinally add any specific dataset config here -@json_schema_type -class EvaluateResponse(BaseModel): - """The response from an evaluation. - - :param generations: The generations from the evaluation. - :param scores: The scores from the evaluation. - """ - - generations: List[Dict[str, Any]] - # each key in the dict is a scoring function name - scores: Dict[str, ScoringResult] - - class Eval(Protocol): """Llama Stack Evaluation API for running evaluations on model and agent candidates.""" - @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST") - async def run_eval( + @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST") + async def evaluate( self, benchmark_id: str, benchmark_config: BenchmarkConfig, - ) -> Job: - """Run an evaluation on a benchmark. + ) -> EvaluateJob: ... - :param benchmark_id: The ID of the benchmark to run the evaluation on. - :param benchmark_config: The configuration for the benchmark. - :return: The job that was created to run the evaluation. - """ + # CRUD operations on running jobs + @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET") + async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ... - @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST") - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - """Evaluate a list of rows on a benchmark. + @webmethod(route="/evaluate/jobs", method="GET") + async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ... - :param benchmark_id: The ID of the benchmark to run the evaluation on. - :param input_rows: The rows to evaluate. - :param scoring_functions: The scoring functions to use for the evaluation. - :param benchmark_config: The configuration for the benchmark. - :return: EvaluateResponse object containing generations and scores - """ + @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST") + async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ... - @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") - async def job_status(self, benchmark_id: str, job_id: str) -> Job: - """Get the status of a job. + @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE") + async def delete_evaluate_job(self, job_id: str) -> None: ... - :param benchmark_id: The ID of the benchmark to run the evaluation on. - :param job_id: The ID of the job to get the status of. - :return: The status of the evaluationjob. - """ - ... - - @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE") - async def job_cancel(self, benchmark_id: str, job_id: str) -> None: - """Cancel a job. - - :param benchmark_id: The ID of the benchmark to run the evaluation on. - :param job_id: The ID of the job to cancel. - """ - ... - - @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET") - async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: - """Get the result of a job. - - :param benchmark_id: The ID of the benchmark to run the evaluation on. - :param job_id: The ID of the job to get the result of. - :return: The result of the job. - """ + # Note: pause/resume/cancel are achieved as follows: + # - POST with status=paused + # - POST with status=resuming + # - POST with status=cancelled diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index e5f1bcb65..46caed8ac 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, Union @@ -12,8 +11,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.job_types import JobStatus -from llama_stack.apis.common.training_types import Checkpoint +from llama_stack.apis.common.job_types import BaseJob from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi register_schema(AlgorithmConfig, name="AlgorithmConfig") -@json_schema_type -class PostTrainingJobLogStream(BaseModel): - """Stream of logs from a finetuning job.""" - - job_uuid: str - log_lines: List[str] - - @json_schema_type class RLHFAlgorithm(Enum): dpo = "dpo" @@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel): logger_config: Dict[str, Any] -class PostTrainingJob(BaseModel): - job_uuid: str - - @json_schema_type -class PostTrainingJobStatusResponse(BaseModel): - """Status of a finetuning job.""" - - job_uuid: str - status: JobStatus - - scheduled_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - resources_allocated: Optional[Dict[str, Any]] = None - - checkpoints: List[Checkpoint] = Field(default_factory=list) +class PostTrainingJob(BaseJob, BaseModel): + type: Literal["post-training"] = "post-training" class ListPostTrainingJobsResponse(BaseModel): - data: List[PostTrainingJob] - - -@json_schema_type -class PostTrainingJobArtifactsResponse(BaseModel): - """Artifacts of a finetuning job.""" - - job_uuid: str - checkpoints: List[Checkpoint] = Field(default_factory=list) - - # TODO(ashwin): metrics, evals + data: list[PostTrainingJob] class PostTraining(Protocol): + # This is how you create a new job - POST against the root endpoint @webmethod(route="/post-training/supervised-fine-tune", method="POST") async def supervised_fine_tune( self, @@ -196,14 +162,20 @@ class PostTraining(Protocol): logger_config: Dict[str, Any], ) -> PostTrainingJob: ... + # CRUD operations on running jobs + @webmethod(route="/post-training/jobs/{job_id:path}", method="GET") + async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ... + @webmethod(route="/post-training/jobs", method="GET") - async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... + async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ... - @webmethod(route="/post-training/job/status", method="GET") - async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ... + @webmethod(route="/post-training/jobs/{job_id:path}", method="POST") + async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ... - @webmethod(route="/post-training/job/cancel", method="POST") - async def cancel_training_job(self, job_uuid: str) -> None: ... + @webmethod(route="/post-training/job/{job_id:path}", method="DELETE") + async def delete_post_training_job(self, job_id: str) -> None: ... - @webmethod(route="/post-training/job/artifacts", method="GET") - async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ... + # Note: pause/resume/cancel are achieved as follows: + # - POST with status=paused + # - POST with status=resuming + # - POST with status=cancelled diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 7b41192af..a05d04225 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -5,10 +5,11 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import List, Literal, Optional, Protocol from pydantic import BaseModel +from llama_stack.apis.common.job_types import BaseJob from llama_stack.apis.inference import Message from llama_stack.schema_utils import json_schema_type, webmethod @@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel): @json_schema_type -class SyntheticDataGenerationResponse(BaseModel): - """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" +class SyntheticDataGenerationJob(BaseJob, BaseModel): + type: Literal["synthetic-data-generation"] = "synthetic-data-generation" - synthetic_data: List[Dict[str, Any]] - statistics: Optional[Dict[str, Any]] = None + +@json_schema_type +class ListSyntheticDataGenerationJobsResponse(BaseModel): + items: list[SyntheticDataGenerationJob] class SyntheticDataGeneration(Protocol): @@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol): dialogs: List[Message], filtering_function: FilteringFunction = FilteringFunction.none, model: Optional[str] = None, - ) -> Union[SyntheticDataGenerationResponse]: ... + ) -> SyntheticDataGenerationJob: ... + + # CRUD operations on running jobs + @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET") + async def get_synthetic_data_generation_job(self) -> SyntheticDataGenerationJob: ... + + @webmethod(route="/synthetic-data-generation/jobs", method="GET") + async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ... + + @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST") + async def update_synthetic_data_generation_job( + self, job: SyntheticDataGenerationJob + ) -> SyntheticDataGenerationJob: ... + + @webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE") + async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ... + + # Note: pause/resume/cancel are achieved as follows: + # - POST with status=paused + # - POST with status=resuming + # - POST with status=cancelled diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 17aecdaf8..e1d21f400 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -16,7 +16,7 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import DatasetPurpose, DataSource -from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -779,61 +779,32 @@ class EvalRouter(Eval): logger.debug("EvalRouter.shutdown") pass - async def run_eval( + async def evaluate( self, benchmark_id: str, benchmark_config: BenchmarkConfig, - ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + ) -> EvaluateJob: + logger.debug(f"EvalRouter.evaluate: {benchmark_id}") + return await self.routing_table.get_provider_impl(benchmark_id).evaluate( benchmark_id=benchmark_id, benchmark_config=benchmark_config, ) - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( - benchmark_id=benchmark_id, - input_rows=input_rows, - scoring_functions=scoring_functions, - benchmark_config=benchmark_config, - ) + async def get_evaluate_job(self, job_id: str) -> EvaluateJob: + logger.debug(f"EvalRouter.get_evaluate_job: {job_id}") + return await self.routing_table.get_provider_impl("eval").get_evaluate_job(job_id) - async def job_status( - self, - benchmark_id: str, - job_id: str, - ) -> Job: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: + logger.debug("EvalRouter.list_evaluate_jobs") + return await self.routing_table.get_provider_impl("eval").list_evaluate_jobs() - async def job_cancel( - self, - benchmark_id: str, - job_id: str, - ) -> None: - logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( - benchmark_id, - job_id, - ) + async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: + logger.debug(f"EvalRouter.update_evaluate_job: {job.id}") + return await self.routing_table.get_provider_impl("eval").update_evaluate_job(job) - async def job_result( - self, - benchmark_id: str, - job_id: str, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( - benchmark_id, - job_id, - ) + async def delete_evaluate_job(self, job_id: str) -> None: + logger.debug(f"EvalRouter.delete_evaluate_job: {job_id}") + return await self.routing_table.get_provider_impl("eval").delete_evaluate_job(job_id) class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 7c28f1bb7..8a4373b1c 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( ) from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.schema_utils import webmethod -from .....apis.common.job_types import Job, JobStatus -from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse +from .....apis.common.job_types import JobArtifact, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "benchmarks:" @@ -75,11 +76,11 @@ class MetaReferenceEvalImpl( ) self.benchmarks[task_def.identifier] = task_def - async def run_eval( + async def evaluate( self, benchmark_id: str, benchmark_config: BenchmarkConfig, - ) -> Job: + ) -> EvaluateJob: task_def = self.benchmarks[benchmark_id] dataset_id = task_def.dataset_id scoring_functions = task_def.scoring_functions @@ -91,18 +92,35 @@ class MetaReferenceEvalImpl( dataset_id=dataset_id, limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), ) - res = await self.evaluate_rows( + + generations, scoring_results = await self._evaluate_rows( benchmark_id=benchmark_id, input_rows=all_rows.data, scoring_functions=scoring_functions, benchmark_config=benchmark_config, ) + artifacts = [ + JobArtifact( + type="generation", + name=f"generation-{i}", + metadata=generation, + ) + for i, generation in enumerate(generations) + ] + [ + JobArtifact( + type="scoring_results", + name="scoring_results", + metadata=scoring_results, + ) + ] # TODO: currently needs to wait for generation before returning # need job scheduler queue (ray/celery) w/ jobs api job_id = str(len(self.jobs)) - self.jobs[job_id] = res - return Job(job_id=job_id, status=JobStatus.completed) + job = EvaluateJob(id=job_id, artifacts=artifacts) + job.update_status(JobStatus.completed) + self.jobs[job_id] = job + return job async def _run_agent_generation( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig @@ -182,13 +200,13 @@ class MetaReferenceEvalImpl( return generations - async def evaluate_rows( + async def _evaluate_rows( self, benchmark_id: str, input_rows: List[Dict[str, Any]], scoring_functions: List[str], benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: + ) -> tuple[list[dict[str, Any]], dict[str, Any]]: candidate = benchmark_config.eval_candidate if candidate.type == "agent": generations = await self._run_agent_generation(input_rows, benchmark_config) @@ -214,21 +232,26 @@ class MetaReferenceEvalImpl( input_rows=score_input_rows, scoring_functions=scoring_functions_dict ) - return EvaluateResponse(generations=generations, scores=score_response.results) - - async def job_status(self, benchmark_id: str, job_id: str) -> Job: - if job_id in self.jobs: - return Job(job_id=job_id, status=JobStatus.completed) - - raise ValueError(f"Job {job_id} not found") - - async def job_cancel(self, benchmark_id: str, job_id: str) -> None: - raise NotImplementedError("Job cancel is not implemented yet") - - async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: - job = await self.job_status(benchmark_id, job_id) - status = job.status - if not status or status != JobStatus.completed: - raise ValueError(f"Job is not completed, Status: {status.value}") + return generations, score_response.results + # CRUD operations on running jobs + @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET") + async def get_evaluate_job(self, job_id: str) -> EvaluateJob: return self.jobs[job_id] + + @webmethod(route="/evaluate/jobs", method="GET") + async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: + return ListEvaluateJobsResponse(data=list(self.jobs.values())) + + @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST") + async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: + raise NotImplementedError + + @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE") + async def delete_evaluate_job(self, job_id: str) -> None: + raise NotImplementedError + + # Note: pause/resume/cancel are achieved as follows: + # - POST with status=paused + # - POST with status=resuming + # - POST with status=cancelled diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index cc1a6a5fe..e18c56f08 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( AlgorithmConfig, - Checkpoint, DPOAlignmentConfig, - JobStatus, ListPostTrainingJobsResponse, LoraFinetuningConfig, PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, TrainingConfig, ) from llama_stack.providers.inline.post_training.torchtune.config import ( @@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl: async def shutdown(self) -> None: await self._scheduler.shutdown() - @staticmethod - def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact: - return JobArtifact( - type=TrainingArtifactType.CHECKPOINT.value, - name=checkpoint.identifier, - uri=checkpoint.path, - metadata=dict(checkpoint), - ) - @staticmethod def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact: return JobArtifact( @@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl: self.datasetio_api, self.datasets_api, ) + await recipe.setup() resources_allocated, checkpoints = await recipe.train() on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) for checkpoint in checkpoints: - artifact = self._checkpoint_to_artifact(checkpoint) - on_artifact_collected_cb(artifact) + on_artifact_collected_cb(checkpoint) on_status_change_cb(SchedulerJobStatus.completed) on_log_message_cb("Lora finetuning completed") @@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl: raise NotImplementedError() job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) + + # TODO: initialize with more data from scheduler return PostTrainingJob(job_uuid=job_uuid) async def preference_optimize( @@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl: logger_config: Dict[str, Any], ) -> PostTrainingJob: ... - async def get_training_jobs(self) -> ListPostTrainingJobsResponse: + # TODO: should these be under post-training/supervised-fine-tune/? + # CRUD operations on running jobs + @webmethod(route="/post-training/jobs/{job_id:path}", method="GET") + async def get_post_training_job(self, job_id: str) -> PostTrainingJob: + # TODO: implement + raise NotImplementedError + + @webmethod(route="/post-training/jobs", method="GET") + async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: + # TODO: populate other data return ListPostTrainingJobsResponse( data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] ) - @staticmethod - def _get_artifacts_metadata_by_type(job, artifact_type): - return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type] + @webmethod(route="/post-training/jobs/{job_id:path}", method="POST") + async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: + # TODO: implement + raise NotImplementedError - @classmethod - def _get_checkpoints(cls, job): - return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value) + @webmethod(route="/post-training/job/{job_id:path}", method="DELETE") + async def delete_post_training_job(self, job_id: str) -> None: + # TODO: implement + raise NotImplementedError - @classmethod - def _get_resources_allocated(cls, job): - data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) - return data[0] if data else None - - @webmethod(route="/post-training/job/status") - async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: - job = self._scheduler.get_job(job_uuid) - - match job.status: - # TODO: Add support for other statuses to API - case SchedulerJobStatus.new | SchedulerJobStatus.scheduled: - status = JobStatus.scheduled - case SchedulerJobStatus.running: - status = JobStatus.in_progress - case SchedulerJobStatus.completed: - status = JobStatus.completed - case SchedulerJobStatus.failed: - status = JobStatus.failed - case _: - raise NotImplementedError() - - return PostTrainingJobStatusResponse( - job_uuid=job_uuid, - status=status, - scheduled_at=job.scheduled_at, - started_at=job.started_at, - completed_at=job.completed_at, - checkpoints=self._get_checkpoints(job), - resources_allocated=self._get_resources_allocated(job), - ) - - @webmethod(route="/post-training/job/cancel") - async def cancel_training_job(self, job_uuid: str) -> None: - self._scheduler.cancel(job_uuid) - - @webmethod(route="/post-training/job/artifacts") - async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: - job = self._scheduler.get_job(job_uuid) - return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) + # Note: pause/resume/cancel are achieved as follows: + # - POST with status=paused + # - POST with status=resuming + # - POST with status=cancelled diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 04bf86b97..d139e8cdb 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup from torchtune.training.metric_logging import DiskLogger from tqdm import tqdm +from llama_stack.apis.common.job_types import JobArtifact from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( - Checkpoint, DataConfig, EfficiencyConfig, LoraFinetuningConfig, @@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice: return loss - async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: + async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]: """ The core training loop. """ @@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice: self.epochs_run += 1 log.info("Starting checkpoint save...") checkpoint_path = await self.save_checkpoint(epoch=curr_epoch) - checkpoint = Checkpoint( - identifier=f"{self.model_id}-sft-{curr_epoch}", - created_at=datetime.now(timezone.utc), - epoch=curr_epoch, - post_training_job_id=self.job_uuid, - path=checkpoint_path, + + checkpoint = JobArtifact( + name=f"{self.model_id}-sft-{curr_epoch}", + type="checkpoint", + # TODO: this should be exposed via /files instead + uri=checkpoint_path, ) + + metadata = { + "created_at": datetime.now(timezone.utc), + "epoch": curr_epoch, + } if self.training_config.data_config.validation_dataset_id: validation_loss, perplexity = await self.validation() training_metrics = PostTrainingMetric( @@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice: validation_loss=validation_loss, perplexity=perplexity, ) - checkpoint.training_metrics = training_metrics + metadata["training_metrics"] = training_metrics + checkpoint.metadata = metadata + checkpoints.append(checkpoint) # clean up the memory after training finishes diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index e14fcf0cc..18d563641 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -4,19 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import warnings -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, Optional import aiohttp -from pydantic import BaseModel, ConfigDict +from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.post_training import ( AlgorithmConfig, DPOAlignmentConfig, - JobStatus, + ListPostTrainingJobsResponse, PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, TrainingConfig, ) from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig @@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe from .models import _MODEL_ENTRIES -# Map API status to JobStatus enum -STATUS_MAPPING = { - "running": "in_progress", - "completed": "completed", - "failed": "failed", - "cancelled": "cancelled", - "pending": "scheduled", -} - - -class NvidiaPostTrainingJob(PostTrainingJob): - """Parse the response from the Customizer API. - Inherits job_uuid from PostTrainingJob. - Adds status, created_at, updated_at parameters. - Passes through all other parameters from data field in the response. - """ - - model_config = ConfigDict(extra="allow") - status: JobStatus - created_at: datetime - updated_at: datetime - - -class ListNvidiaPostTrainingJobs(BaseModel): - data: List[NvidiaPostTrainingJob] - - -class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse): - model_config = ConfigDict(extra="allow") - class NvidiaPostTrainingAdapter(ModelRegistryHelper): def __init__(self, config: NvidiaPostTrainingConfig): @@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): raise Exception(f"API request failed: {error_data}") return await response.json() - async def get_training_jobs( - self, - page: Optional[int] = 1, - page_size: Optional[int] = 10, - sort: Optional[Literal["created_at", "-created_at"]] = "created_at", - ) -> ListNvidiaPostTrainingJobs: - """Get all customization jobs. - Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs. + raise Exception(f"API request failed after {self.config.max_retries} retries") - Returns a ListNvidiaPostTrainingJobs object with the following fields: - - data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects + @staticmethod + def _get_job_status(job: Dict[str, Any]) -> JobStatus: + job_status = job.get("status", "unknown").lower() + try: + return JobStatus(job_status) + except ValueError: + return JobStatus.unknown + + # TODO: fetch just the necessary job from remote + async def get_post_training_job(self, job_id: str) -> PostTrainingJob: + jobs = await self.list_post_training_jobs() + for job in jobs.data: + if job.id == job_id: + return job + raise ValueError(f"Job with ID {job_id} not found") + + async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: + """Get all customization jobs. ToDo: Support for schema input for filtering. """ - params = {"page": page, "page_size": page_size, "sort": sort} - + # TODO: don't hardcode pagination params + params = {"page": 1, "page_size": 10, "sort": "created_at"} response = await self._make_request("GET", "/v1/customization/jobs", params=params) jobs = [] - for job in response.get("data", []): - job_id = job.pop("id") - job_status = job.pop("status", "unknown").lower() - mapped_status = STATUS_MAPPING.get(job_status, "unknown") + for job_dict in response.get("data", []): + # TODO: expose artifacts + job = PostTrainingJob(**job_dict) + job.update_status(self._get_job_status(job_dict)) + jobs.append(job) - # Convert string timestamps to datetime objects - created_at = ( - datetime.fromisoformat(job.pop("created_at")) - if "created_at" in job - else datetime.now(tz=datetime.timezone.utc) - ) - updated_at = ( - datetime.fromisoformat(job.pop("updated_at")) - if "updated_at" in job - else datetime.now(tz=datetime.timezone.utc) - ) + return ListPostTrainingJobsResponse(data=jobs) - # Create NvidiaPostTrainingJob instance - jobs.append( - NvidiaPostTrainingJob( - job_uuid=job_id, - status=JobStatus(mapped_status), - created_at=created_at, - updated_at=updated_at, - **job, - ) - ) - - return ListNvidiaPostTrainingJobs(data=jobs) - - async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse: - """Get the status of a customization job. - Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob. - - Returns a NvidiaPostTrainingJob object with the following fields: - - job_uuid: str - Unique identifier for the job - - status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled) - - created_at: datetime - The time when the job was created - - updated_at: datetime - The last time the job status was updated - - Additional fields that may be included: - - steps_completed: Optional[int] - Number of training steps completed - - epochs_completed: Optional[int] - Number of epochs completed - - percentage_done: Optional[float] - Percentage of training completed (0-100) - - best_epoch: Optional[int] - The epoch with the best performance - - train_loss: Optional[float] - Training loss of the best checkpoint - - val_loss: Optional[float] - Validation loss of the best checkpoint - - metrics: Optional[Dict] - Additional training metrics - - status_logs: Optional[List] - Detailed logs of status changes - """ - response = await self._make_request( - "GET", - f"/v1/customization/jobs/{job_uuid}/status", - params={"job_id": job_uuid}, - ) - - api_status = response.pop("status").lower() - mapped_status = STATUS_MAPPING.get(api_status, "unknown") - - return NvidiaPostTrainingJobStatusResponse( - status=JobStatus(mapped_status), - job_uuid=job_uuid, - started_at=datetime.fromisoformat(response.pop("created_at")), - updated_at=datetime.fromisoformat(response.pop("updated_at")), - **response, - ) - - async def cancel_training_job(self, job_uuid: str) -> None: + async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob: + if status is None: + raise ValueError("Status must be provided") + if status not in {JobStatus.cancelled}: + raise ValueError(f"Unsupported status: {status}") await self._make_request( - method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid} + method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id} ) + return await self.get_post_training_job(job_id) - async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: - raise NotImplementedError("Job artifacts are not implemented yet") - - async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: - raise NotImplementedError("Job artifacts are not implemented yet") + async def delete_post_training_job(self, job_id: str) -> None: + raise NotImplementedError("Delete job is not implemented yet") async def supervised_fine_tune( self, @@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): model: str, checkpoint_dir: Optional[str], algorithm_config: Optional[AlgorithmConfig] = None, - ) -> NvidiaPostTrainingJob: + ) -> PostTrainingJob: """ Fine-tunes a model on a dataset. Currently only supports Lora finetuning for standlone docker container. @@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): headers={"Accept": "application/json"}, json=job_config, ) - - job_uuid = response["id"] response.pop("status") - created_at = datetime.fromisoformat(response.pop("created_at")) - updated_at = datetime.fromisoformat(response.pop("updated_at")) - return NvidiaPostTrainingJob( - job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response - ) + # TODO: expose artifacts + job = PostTrainingJob(**response) + job.update_status(JobStatus.running) + return job async def preference_optimize( self, @@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): ) -> PostTrainingJob: """Optimize a model based on preference data.""" raise NotImplementedError("Preference optimization is not implemented yet") - - async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse: - raise NotImplementedError("Job logs are not implemented yet") diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py index a75a170cf..ebfcd74f8 100644 --- a/llama_stack/strong_typing/inspection.py +++ b/llama_stack/strong_typing/inspection.py @@ -562,6 +562,15 @@ else: return typing.get_type_hints(typ) +def get_computed_fields(typ: type) -> dict[str, type]: + "Returns all computed fields of a class." + pydantic_decorators = getattr(typ, "__pydantic_decorators__", None) + if not pydantic_decorators: + return {} + computed_fields = pydantic_decorators.computed_fields + return {field_name: decorator.info.return_type for field_name, decorator in computed_fields.items()} + + def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]: "Returns all properties of a class." @@ -569,7 +578,8 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]: return ((field.name, field.type) for field in dataclasses.fields(typ)) else: resolved_hints = get_resolved_hints(typ) - return resolved_hints.items() + computed_fields = get_computed_fields(typ) + return (resolved_hints | computed_fields).items() def get_class_property(typ: type, name: str) -> Optional[type | str]: diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index 3e22bc5a7..0768ff6a5 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -8,14 +8,12 @@ from typing import List import pytest from llama_stack.apis.common.job_types import JobStatus +from llama_stack.apis.common.training_types import Checkpoint from llama_stack.apis.post_training import ( - Checkpoint, DataConfig, LoraFinetuningConfig, OptimizerConfig, PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, TrainingConfig, ) @@ -84,7 +82,6 @@ class TestPostTraining: async def test_get_training_job_status(self, post_training_stack): post_training_impl = post_training_stack job_status = await post_training_impl.get_training_job_status("1234") - assert isinstance(job_status, PostTrainingJobStatusResponse) assert job_status.job_uuid == "1234" assert job_status.status == JobStatus.completed assert isinstance(job_status.checkpoints[0], Checkpoint) @@ -93,7 +90,6 @@ class TestPostTraining: async def test_get_training_job_artifacts(self, post_training_stack): post_training_impl = post_training_stack job_artifacts = await post_training_impl.get_training_job_artifacts("1234") - assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) assert job_artifacts.job_uuid == "1234" assert isinstance(job_artifacts.checkpoints[0], Checkpoint) assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0" diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 7ce89144b..d6470962b 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import ( TrainingConfigOptimizerConfig, ) +from llama_stack.apis.common.job_types import JobStatus +from llama_stack.apis.post_training import ( + ListPostTrainingJobsResponse, + PostTrainingJob, +) from llama_stack.providers.remote.post_training.nvidia.post_training import ( - ListNvidiaPostTrainingJobs, NvidiaPostTrainingAdapter, NvidiaPostTrainingConfig, - NvidiaPostTrainingJob, - NvidiaPostTrainingJobStatusResponse, ) @@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase): def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None): """Helper method to verify request details in mock calls.""" - call_args = mock_call.call_args + found = False + for call_args in mock_call.call_args_list: + if expected_method and expected_path: + if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: + if call_args[0] == (expected_method, expected_path): + found = True + else: + if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path: + found = True - if expected_method and expected_path: - if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: - assert call_args[0] == (expected_method, expected_path) - else: - assert call_args[1]["method"] == expected_method - assert call_args[1]["path"] == expected_path + if expected_params: + if call_args[1]["params"] == expected_params: + found = True - if expected_params: - assert call_args[1]["params"] == expected_params - - if expected_json: - for key, value in expected_json.items(): - assert call_args[1]["json"][key] == value + if expected_json: + for key, value in expected_json.items(): + if call_args[1]["json"][key] == value: + found = True + assert found def test_supervised_fine_tune(self): """Test the supervised fine-tuning API call.""" @@ -151,9 +157,8 @@ class TestNvidiaPostTraining(unittest.TestCase): ) ) - # check the output is a PostTrainingJob - assert isinstance(training_job, NvidiaPostTrainingJob) - assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" + assert isinstance(training_job, PostTrainingJob) + assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2" self.mock_make_request.assert_called_once() self._assert_request( @@ -199,38 +204,7 @@ class TestNvidiaPostTraining(unittest.TestCase): ) ) - def test_get_training_job_status(self): - self.mock_make_request.return_value = { - "created_at": "2024-12-09T04:06:28.580220", - "updated_at": "2024-12-09T04:21:19.852832", - "status": "completed", - "steps_completed": 1210, - "epochs_completed": 2, - "percentage_done": 100.0, - "best_epoch": 2, - "train_loss": 1.718016266822815, - "val_loss": 1.8661999702453613, - } - - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - - status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id)) - - assert isinstance(status, NvidiaPostTrainingJobStatusResponse) - assert status.status.value == "completed" - assert status.steps_completed == 1210 - assert status.epochs_completed == 2 - assert status.percentage_done == 100.0 - assert status.best_epoch == 2 - assert status.train_loss == 1.718016266822815 - assert status.val_loss == 1.8661999702453613 - - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id} - ) - - def test_get_training_jobs(self): + def test_list_post_training_jobs(self): job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" self.mock_make_request.return_value = { "data": [ @@ -258,12 +232,12 @@ class TestNvidiaPostTraining(unittest.TestCase): ] } - jobs = self.run_async(self.adapter.get_training_jobs()) + jobs = self.run_async(self.adapter.list_post_training_jobs()) - assert isinstance(jobs, ListNvidiaPostTrainingJobs) + assert isinstance(jobs, ListPostTrainingJobsResponse) assert len(jobs.data) == 1 job = jobs.data[0] - assert job.job_uuid == job_id + assert job.id == job_id assert job.status.value == "completed" self.mock_make_request.assert_called_once() @@ -275,14 +249,36 @@ class TestNvidiaPostTraining(unittest.TestCase): ) def test_cancel_training_job(self): - self.mock_make_request.return_value = {} # Empty response for successful cancellation job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" + self.mock_make_request.return_value = { + "data": [ + { + "id": job_id, + "created_at": "2024-12-09T04:06:28.542884", + "updated_at": "2024-12-09T04:21:19.852832", + "config": { + "name": "meta-llama/Llama-3.1-8B-Instruct", + "base_model": "meta-llama/Llama-3.1-8B-Instruct", + }, + "dataset": {"name": "default/sample-basic-test"}, + "hyperparameters": { + "finetuning_type": "lora", + "training_type": "sft", + "batch_size": 16, + "epochs": 2, + "learning_rate": 0.0001, + "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, + }, + "output_model": "default/job-1234", + "status": "completed", + "project": "default", + } + ] + } - result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id)) + result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled)) + assert result.id == job_id - assert result is None - - self.mock_make_request.assert_called_once() self._assert_request( self.mock_make_request, "POST",