feat: add /jobs API

This API will be later tied to jobs as defined for specific flows
(post-training, eval, etc.) through the common scheduler mechanism.

Note: At the moment, the API does nothing useful, except returning Not
Implemented errors when called.

This is an alternative to developing per-flow jobs APIs. Eventually,
once /jobs API is implemented, we should be able to deprecate existing
APIs under /v1/post-training/, /v1/eval/ etc.

See #1587 (tracker)
See #1238 (design details)

Note: This is an alternative path to #1582 and #1583.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-12 16:15:30 -04:00
parent 0fdb15bcc7
commit 90799cdcee
12 changed files with 557 additions and 11 deletions

View file

@ -36,6 +36,7 @@ class Api(Enum):
# built-in API
inspect = "inspect"
jobs = "jobs"
@json_schema_type

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .jobs import * # noqa: F401 F403

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class JobArtifact(BaseModel):
    """An artifact produced by a job (e.g. a checkpoint or a report)."""

    # Human-readable name of the artifact.
    name: str
    # Kind of artifact; value set is defined by the producing flow — TODO confirm.
    type: str
    # Location where the artifact's content can be retrieved.
    uri: str
    # Free-form, artifact-specific key/value data (untyped by design here).
    metadata: dict
@json_schema_type
class JobInfo(BaseModel):
    """Status and metadata for a single scheduled job."""

    # Unique identifier of the job.
    uuid: str
    # Kind of job; semantics defined by the owning flow — TODO confirm the value set.
    type: str
    # Current lifecycle state; presumably a closed set of states — verify against scheduler.
    status: str
    # Timestamps are optional because a job may not have reached that stage yet.
    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Artifacts produced so far; required field (may be an empty list).
    artifacts: List[JobArtifact]
class ListJobsResponse(BaseModel):
    """Response envelope for GET /jobs: the list of known jobs."""

    # All jobs visible to the caller.
    data: List[JobInfo]
@runtime_checkable
class Jobs(Protocol):
    """Common API for managing jobs scheduled by other flows.

    Intended to eventually replace per-flow job endpoints (post-training,
    eval, ...) behind a single scheduler-backed surface.
    """

    @webmethod(route="/jobs/{job_id}/cancel", method="POST")
    async def cancel_job(
        self,
        job_id: str,
    ) -> None:
        """Request cancellation of the job identified by ``job_id``."""
        ...

    @webmethod(route="/jobs/{job_id}", method="DELETE")
    async def delete_job(
        self,
        job_id: str,
    ) -> None:
        """Delete the job identified by ``job_id`` and its records."""
        ...

    @webmethod(route="/jobs", method="GET")
    async def list_jobs(self) -> ListJobsResponse:
        """Return all known jobs."""
        ...

    @webmethod(route="/jobs/{job_id}", method="GET")
    async def get_job(
        self,
        job_id: str,
    ) -> JobInfo:
        """Return status and metadata for the job identified by ``job_id``."""
        ...

View file

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
class ResourceType(Enum):
model = "model"
job = "job"
shield = "shield"
vector_db = "vector_db"
dataset = "dataset"

View file

@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
return [api for api in Api if api not in routing_table_apis and api not in (Api.inspect, Api.jobs)]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:

View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.jobs import (
JobInfo,
Jobs,
ListJobsResponse,
)
from llama_stack.distribution.datatypes import StackRunConfig
class DistributionJobsConfig(BaseModel):
    """Config for the built-in jobs provider: carries the stack run config."""

    # Full run configuration of the distribution this provider serves.
    run_config: StackRunConfig
async def get_provider_impl(config, deps):
    """Construct, initialize, and return the built-in jobs provider."""
    provider = DistributionJobsImpl(config, deps)
    await provider.initialize()
    return provider
class DistributionJobsImpl(Jobs):
    """Built-in provider backing the /jobs API.

    Placeholder: every endpoint raises NotImplementedError until the common
    scheduler integration lands.
    """

    def __init__(self, config, deps):
        # Keep the provider config and resolved API dependencies for later use.
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        """No startup work is needed yet."""

    async def shutdown(self) -> None:
        """No teardown work is needed yet."""

    async def cancel_job(self, job_id: str) -> None:
        raise NotImplementedError

    async def delete_job(self, job_id: str) -> None:
        raise NotImplementedError

    async def list_jobs(self) -> ListJobsResponse:
        raise NotImplementedError

    async def get_job(self, job_id: str) -> JobInfo:
        raise NotImplementedError

View file

@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.jobs import Jobs
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
@ -62,6 +63,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.jobs: Jobs,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
@ -226,26 +228,32 @@ def sort_providers_by_deps(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
# Append built-in "inspect" provider
# Append built-in providers
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
deps = [x.value for x in apis]
config = run_config.model_dump()
sorted_providers += [
(
"inspect",
name,
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
config={"run_config": config},
spec=InlineProviderSpec(
api=Api.inspect,
api=api,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
config_class=f"llama_stack.distribution.{name}.Distribution{name.title()}Config",
module=f"llama_stack.distribution.{name}",
api_dependencies=apis,
deps__=[x.value for x in apis],
deps__=deps,
),
),
)
)
for name, api in [
("inspect", Api.inspect),
("jobs", Api.jobs),
]
]
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:

View file

@ -367,7 +367,9 @@ def main():
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
# also include builtin APIs
# NOTE(review): the `.add(...)` calls above indicate `apis_to_serve` is a set;
# `+=` is not supported between sets (no `+` operator) and will raise
# TypeError at runtime — use `apis_to_serve |= {"inspect", "jobs"}` or
# `apis_to_serve.update({"inspect", "jobs"})` instead.
apis_to_serve += {"inspect", "jobs"}
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -21,6 +21,7 @@ from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.jobs import Jobs
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
@ -62,6 +63,7 @@ class LlamaStack(
Models,
Shields,
Inspect,
Jobs,
ToolGroups,
ToolRuntime,
RAGToolRuntime,