diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 709360ede..618153319 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -230,6 +230,41 @@ } } }, + "/v1/jobs/{job_id}/cancel": { + "post": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Jobs" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/post-training/job/cancel": { "post": { "responses": { @@ -925,6 +960,81 @@ ] } }, + "/v1/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JobInfo" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Jobs" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Jobs" + ], + "description": "", + "parameters": [ + { + "name": "job_id", + "in": "path", + "required": 
true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/inference/embeddings": { "post": { "responses": { @@ -2568,6 +2678,39 @@ ] } }, + "/v1/jobs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListJobsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Jobs" + ], + "description": "", + "parameters": [] + } + }, "/v1/models": { "get": { "responses": { @@ -4715,6 +4858,12 @@ "CompletionResponse": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricEvent" + } + }, "content": { "type": "string", "description": "The generated completion text" @@ -5082,6 +5231,12 @@ "CompletionResponseStreamChunk": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricEvent" + } + }, "delta": { "type": "string", "description": "New content generated since last chunk. This can be one or more tokens." @@ -7094,6 +7249,73 @@ ], "title": "UnionType" }, + "JobArtifact": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "type": { + "type": "string" + }, + "uri": { + "type": "string" + }, + "metadata": { + "type": "object", + "title": "dict", + "description": "dict() -> new empty dictionary dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs dict(iterable) -> new dictionary initialized as if via: d = {} for k, v in iterable: d[k] = v dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. 
For example: dict(one=1, two=2)" + } + }, + "additionalProperties": false, + "required": [ + "name", + "type", + "uri", + "metadata" + ], + "title": "JobArtifact" + }, + "JobInfo": { + "type": "object", + "properties": { + "uuid": { + "type": "string" + }, + "type": { + "type": "string" + }, + "status": { + "type": "string" + }, + "scheduled_at": { + "type": "string", + "format": "date-time" + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "completed_at": { + "type": "string", + "format": "date-time" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobArtifact" + } + } + }, + "additionalProperties": false, + "required": [ + "uuid", + "type", + "status", + "artifacts" + ], + "title": "JobInfo" + }, "Model": { "type": "object", "properties": { @@ -8157,6 +8379,22 @@ "title": "ListFileResponse", "description": "Response representing a list of file entries." }, + "ListJobsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/JobInfo" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "ListJobsResponse" + }, "ListModelsResponse": { "type": "object", "properties": { @@ -10119,6 +10357,9 @@ { "name": "Inspect" }, + { + "name": "Jobs" + }, { "name": "Models" }, @@ -10169,6 +10410,7 @@ "Files (Coming Soon)", "Inference", "Inspect", + "Jobs", "Models", "PostTraining (Coming Soon)", "Safety", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4c00fbe63..8cc7779ea 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -142,6 +142,30 @@ paths: schema: $ref: '#/components/schemas/BatchCompletionRequest' required: true + /v1/jobs/{job_id}/cancel: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + 
#/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Jobs + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string /v1/post-training/job/cancel: post: responses: @@ -633,6 +657,57 @@ paths: required: true schema: type: string + /v1/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/JobInfo' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Jobs + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Jobs + description: '' + parameters: + - name: job_id + in: path + required: true + schema: + type: string /v1/inference/embeddings: post: responses: @@ -1731,6 +1806,29 @@ paths: required: true schema: type: string + /v1/jobs: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListJobsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Jobs + description: '' + parameters: [] /v1/models: get: responses: @@ -3213,6 +3311,10 @@ components: CompletionResponse: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricEvent' 
content: type: string description: The generated completion text @@ -3531,6 +3633,10 @@ components: CompletionResponseStreamChunk: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricEvent' delta: type: string description: >- @@ -4901,6 +5007,60 @@ components: required: - type title: UnionType + JobArtifact: + type: object + properties: + name: + type: string + type: + type: string + uri: + type: string + metadata: + type: object + title: dict + description: >- + dict() -> new empty dictionary dict(mapping) -> new dictionary initialized + from a mapping object's (key, value) pairs dict(iterable) -> new dictionary + initialized as if via: d = {} for k, v in iterable: d[k] + = v dict(**kwargs) -> new dictionary initialized with the name=value pairs in + the keyword argument list. For example: dict(one=1, two=2) + additionalProperties: false + required: + - name + - type + - uri + - metadata + title: JobArtifact + JobInfo: + type: object + properties: + uuid: + type: string + type: + type: string + status: + type: string + scheduled_at: + type: string + format: date-time + started_at: + type: string + format: date-time + completed_at: + type: string + format: date-time + artifacts: + type: array + items: + $ref: '#/components/schemas/JobArtifact' + additionalProperties: false + required: + - uuid + - type + - status + - artifacts + title: JobInfo Model: type: object properties: @@ -5562,6 +5722,17 @@ components: title: ListFileResponse description: >- Response representing a list of file entries. + ListJobsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/JobInfo' + additionalProperties: false + required: + - data + title: ListJobsResponse ListModelsResponse: type: object properties: @@ -6818,6 +6989,7 @@ tags: Llama Stack Inference API for generating completions, chat completions, and embeddings. 
- name: Inspect + - name: Jobs - name: Models - name: PostTraining (Coming Soon) - name: Safety @@ -6842,6 +7014,7 @@ x-tagGroups: - Files (Coming Soon) - Inference - Inspect + - Jobs - Models - PostTraining (Coming Soon) - Safety diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 842a2b63d..866974ee7 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -36,6 +36,7 @@ class Api(Enum): # built-in API inspect = "inspect" + jobs = "jobs" @json_schema_type diff --git a/llama_stack/apis/jobs/__init__.py b/llama_stack/apis/jobs/__init__.py new file mode 100644 index 000000000..2bb4ff26c --- /dev/null +++ b/llama_stack/apis/jobs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .jobs import * # noqa: F401 F403 diff --git a/llama_stack/apis/jobs/jobs.py b/llama_stack/apis/jobs/jobs.py new file mode 100644 index 000000000..af4a03b67 --- /dev/null +++ b/llama_stack/apis/jobs/jobs.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from datetime import datetime +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type, webmethod + + +@json_schema_type +class JobArtifact(BaseModel): + name: str + type: str + uri: str + metadata: Dict[str, Any] + + +@json_schema_type +class JobInfo(BaseModel): + uuid: str + type: str + status: str + + scheduled_at: Optional[datetime] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + artifacts: List[JobArtifact] + + +class ListJobsResponse(BaseModel): + data: List[JobInfo] + + +@runtime_checkable +class Jobs(Protocol): + @webmethod(route="/jobs/{job_id}/cancel", method="POST") + async def cancel_job( + self, + job_id: str, + ) -> None: ... + + @webmethod(route="/jobs/{job_id}", method="DELETE") + async def delete_job( + self, + job_id: str, + ) -> None: ... + + @webmethod(route="/jobs", method="GET") + async def list_jobs(self) -> ListJobsResponse: ... + + @webmethod(route="/jobs/{job_id}", method="GET") + async def get_job( + self, + job_id: str, + ) -> JobInfo: ...
diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 70ec63c55..a23f9d4fd 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field class ResourceType(Enum): model = "model" + job = "job" shield = "shield" vector_db = "vector_db" dataset = "dataset" diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 308081415..e35d2e562 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: def providable_apis() -> List[Api]: routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} - return [api for api in Api if api not in routing_table_apis and api != Api.inspect] + return [api for api in Api if api not in routing_table_apis and api not in (Api.inspect, Api.jobs)] def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: diff --git a/llama_stack/distribution/jobs.py b/llama_stack/distribution/jobs.py new file mode 100644 index 000000000..3c6dcb1e3 --- /dev/null +++ b/llama_stack/distribution/jobs.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from pydantic import BaseModel + +from llama_stack.apis.jobs import ( + JobInfo, + Jobs, + ListJobsResponse, +) +from llama_stack.distribution.datatypes import StackRunConfig + + +class DistributionJobsConfig(BaseModel): + run_config: StackRunConfig + + +async def get_provider_impl(config, deps): + impl = DistributionJobsImpl(config, deps) + await impl.initialize() + return impl + + +class DistributionJobsImpl(Jobs): + def __init__(self, config, deps): + self.config = config + self.deps = deps + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def list_jobs(self) -> ListJobsResponse: + raise NotImplementedError + + async def delete_job(self, job_id: str) -> None: + raise NotImplementedError + + async def cancel_job(self, job_id: str) -> None: + raise NotImplementedError + + async def get_job(self, job_id: str) -> JobInfo: + raise NotImplementedError diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index ab075f399..abbad9ae1 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect +from llama_stack.apis.jobs import Jobs from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining from llama_stack.apis.safety import Safety @@ -62,6 +63,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, + Api.jobs: Jobs, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, @@ -226,26 +228,32 @@ def sort_providers_by_deps( {k: list(v.values()) for k, v in providers_with_specs.items()} ) - # Append built-in "inspect" provider + # Append built-in providers apis = [x[1].spec.api for x in sorted_providers] - sorted_providers.append( + deps = 
[x.value for x in apis] + config = run_config.model_dump() + sorted_providers += [ + ( - "inspect", + name, + ProviderWithSpec( + provider_id="__builtin__", + provider_type="__builtin__", - config={"run_config": run_config.model_dump()}, + config={"run_config": config}, + spec=InlineProviderSpec( - api=Api.inspect, + api=api, + provider_type="__builtin__", - config_class="llama_stack.distribution.inspect.DistributionInspectConfig", - module="llama_stack.distribution.inspect", + config_class=f"llama_stack.distribution.{name}.Distribution{name.title()}Config", + module=f"llama_stack.distribution.{name}", + api_dependencies=apis, - deps__=[x.value for x in apis], + deps__=deps, + ), + ), + ) - ) + for name, api in [ + ("inspect", Api.inspect), + ("jobs", Api.jobs), + ] + ] logger.debug(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 7ca009b13..d438a1559 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -367,7 +367,9 @@ def main(): continue apis_to_serve.add(inf.routing_table_api.value) - apis_to_serve.add("inspect") + # also include builtin APIs + apis_to_serve |= {"inspect", "jobs"} + for api_str in apis_to_serve: api = Api(api_str) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 2b974739a..4a6fce62e 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -21,6 +21,7 @@ from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect +from llama_stack.apis.jobs import Jobs from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining from llama_stack.apis.safety import Safety @@ -62,6 +63,7 @@ class LlamaStack( Models, Shields, Inspect, + Jobs, ToolGroups, ToolRuntime,
RAGToolRuntime, diff --git a/pyproject.toml b/pyproject.toml index 055fa7a55..80afaaef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,7 @@ exclude = [ "^llama_stack/apis/files/files\\.py$", "^llama_stack/apis/inference/inference\\.py$", "^llama_stack/apis/inspect/inspect\\.py$", + "^llama_stack/apis/jobs/jobs\\.py$", "^llama_stack/apis/models/models\\.py$", "^llama_stack/apis/post_training/post_training\\.py$", "^llama_stack/apis/resource\\.py$",