From dd9e1420daf1fd682d94b32a8578b9ca43639fee Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 01:29:18 -0700 Subject: [PATCH] post training job api --- docs/_static/llama-stack-spec.html | 260 ++++++------------ docs/_static/llama-stack-spec.yaml | 176 ++++-------- .../apis/post_training/post_training.py | 5 +- 3 files changed, 133 insertions(+), 308 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 1ef5effef..fb656ffac 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -332,8 +332,55 @@ ] } }, - "/v1/post-training/job/cancel": { - "post": { + "/v1/post-training/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/PostTrainingJob" + }, + { + "type": "null" + } + ] + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "PostTraining (Coming Soon)" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + }, + "delete": { "responses": { "200": { "description": "OK" @@ -355,17 +402,16 @@ "PostTraining (Coming Soon)" ], "description": "", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/CancelTrainingJobRequest" - } + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" } - }, - "required": true - } + } + ] } }, "/v1/inference/chat-completion": { @@ -1869,104 +1915,6 @@ ] } }, - "/v1/post-training/job/artifacts": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" - }, - { - "type": "null" - } - ] - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [ - { - "name": "job_uuid", - "in": "query", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, - "/v1/post-training/job/status": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobStatusResponse" - }, - { - "type": "null" - } - ] - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "PostTraining (Coming Soon)" - ], - "description": "", - "parameters": [ - { - "name": "job_uuid", - "in": "query", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/post-training/jobs": { "get": { "responses": { @@ -3130,7 +3078,7 @@ } } }, - "/v1/post-training/preference-optimize": { + "/v1/post-training/preference-optimize/jobs": { "post": { "responses": { "200": { @@ -3586,7 +3534,7 @@ } } }, - "/v1/post-training/supervised-fine-tune": { + "/v1/post-training/supervised-fine-tune/jobs": { "post": { "responses": { "200": { @@ -4703,19 +4651,6 @@ "title": "CompletionResponse", "description": "Response from a completion request." }, - "CancelTrainingJobRequest": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "CancelTrainingJobRequest" - }, "ToolConfig": { "type": "object", "properties": { @@ -7900,32 +7835,12 @@ "description": "Checkpoint created during training runs", "title": "Checkpoint" }, - "PostTrainingJobArtifactsResponse": { + "PostTrainingJob": { "type": "object", "properties": { - "job_uuid": { - "type": "string" - }, - "checkpoints": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Checkpoint" - } - } - }, - "additionalProperties": false, - "required": [ - "job_uuid", - "checkpoints" - ], - "title": "PostTrainingJobArtifactsResponse", - "description": "Artifacts of a finetuning job." - }, - "PostTrainingJobStatusResponse": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" + "id": { + "type": "string", + "description": "The ID of the job." }, "status": { "type": "string", @@ -7936,19 +7851,21 @@ "scheduled", "cancelled" ], - "title": "JobStatus" + "description": "The status of the job." }, - "scheduled_at": { + "created_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "The time the job was created." }, - "started_at": { + "finished_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "The time the job finished." }, - "completed_at": { + "error": { "type": "string", - "format": "date-time" + "description": "If status of the job is failed, this will contain the error message." }, "resources_allocated": { "type": "object", @@ -7984,12 +7901,12 @@ }, "additionalProperties": false, "required": [ - "job_uuid", + "id", "status", + "created_at", "checkpoints" ], - "title": "PostTrainingJobStatusResponse", - "description": "Status of a finetuning job." + "title": "PostTrainingJob" }, "ListPostTrainingJobsResponse": { "type": "object", @@ -7997,17 +7914,7 @@ "data": { "type": "array", "items": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "PostTrainingJob" + "$ref": "#/components/schemas/PostTrainingJob" } } }, @@ -9042,19 +8949,6 @@ ], "title": "PreferenceOptimizeRequest" }, - "PostTrainingJob": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "PostTrainingJob" - }, "DefaultRAGQueryGeneratorConfig": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 7f5b96051..167f6e563 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -212,8 +212,37 @@ paths: required: true schema: type: string - /v1/post-training/job/cancel: - post: + /v1/post-training/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/PostTrainingJob' + - type: 'null' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - PostTraining (Coming Soon) + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string + delete: responses: '200': description: OK @@ -230,13 +259,12 @@ paths: tags: - PostTraining (Coming Soon) description: '' - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/CancelTrainingJobRequest' - required: true + parameters: + - name: job_id + in: path + required: true + schema: + type: string /v1/inference/chat-completion: post: responses: @@ -1263,66 +1291,6 @@ paths: required: true schema: type: string - /v1/post-training/job/artifacts: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' - - type: 'null' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: - - name: job_uuid - in: query - required: true - schema: - type: string - /v1/post-training/job/status: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobStatusResponse' - - type: 'null' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - PostTraining (Coming Soon) - description: '' - parameters: - - name: job_uuid - in: query - required: true - schema: - type: string /v1/post-training/jobs: get: responses: @@ -2119,7 +2087,7 @@ paths: schema: $ref: '#/components/schemas/LogEventRequest' required: true - /v1/post-training/preference-optimize: + /v1/post-training/preference-optimize/jobs: post: responses: '200': @@ -2435,7 +2403,7 @@ paths: schema: $ref: '#/components/schemas/ScoreRowsRequest' required: true - /v1/post-training/supervised-fine-tune: + /v1/post-training/supervised-fine-tune/jobs: post: responses: '200': @@ -3207,15 +3175,6 @@ components: - stop_reason title: CompletionResponse description: Response from a completion request. - CancelTrainingJobRequest: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: CancelTrainingJobRequest ToolConfig: type: object properties: @@ -5452,26 +5411,12 @@ components: Checkpoint: description: Checkpoint created during training runs title: Checkpoint - PostTrainingJobArtifactsResponse: + PostTrainingJob: type: object properties: - job_uuid: - type: string - checkpoints: - type: array - items: - $ref: '#/components/schemas/Checkpoint' - additionalProperties: false - required: - - job_uuid - - checkpoints - title: PostTrainingJobArtifactsResponse - description: Artifacts of a finetuning job. - PostTrainingJobStatusResponse: - type: object - properties: - job_uuid: + id: type: string + description: The ID of the job. status: type: string enum: @@ -5480,16 +5425,19 @@ components: - failed - scheduled - cancelled - title: JobStatus - scheduled_at: + description: The status of the job. + created_at: type: string format: date-time - started_at: + description: The time the job was created. + finished_at: type: string format: date-time - completed_at: + description: The time the job finished. + error: type: string - format: date-time + description: >- + If status of the job is failed, this will contain the error message. resources_allocated: type: object additionalProperties: @@ -5506,25 +5454,18 @@ components: $ref: '#/components/schemas/Checkpoint' additionalProperties: false required: - - job_uuid + - id - status + - created_at - checkpoints - title: PostTrainingJobStatusResponse - description: Status of a finetuning job. + title: PostTrainingJob ListPostTrainingJobsResponse: type: object properties: data: type: array items: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: PostTrainingJob + $ref: '#/components/schemas/PostTrainingJob' additionalProperties: false required: - data @@ -6192,15 +6133,6 @@ components: - hyperparam_search_config - logger_config title: PreferenceOptimizeRequest - PostTrainingJob: - type: object - properties: - job_uuid: - type: string - additionalProperties: false - required: - - job_uuid - title: PostTrainingJob DefaultRAGQueryGeneratorConfig: type: object properties: diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 079093a5d..58e84eeee 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, Union @@ -12,7 +11,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.job_types import JobCommonFields, JobStatus +from llama_stack.apis.common.job_types import CommonJobFields from llama_stack.apis.common.training_types import Checkpoint from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -140,7 +139,7 @@ class PostTrainingRLHFRequest(BaseModel): @json_schema_type -class PostTrainingJob(JobCommonFields): +class PostTrainingJob(CommonJobFields): resources_allocated: Optional[Dict[str, Any]] = None checkpoints: List[Checkpoint] = Field(default_factory=list)