Mirror of https://github.com/meta-llama/llama-stack.git
feat: add /jobs API
This API will later be tied to jobs as defined for specific flows (post-training, eval, etc.) through the common scheduler mechanism.

Note: At the moment, the API does nothing useful, except returning Not Implemented errors when called.

This is an alternative to developing per-flow jobs APIs. Eventually, once the /jobs API is implemented, we should be able to deprecate the existing APIs under /v1/post-training/, /v1/eval/, etc.

See #1587 (tracker)
See #1238 (design details)

Note: This is an alternative path to #1582 and #1583.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Parent: 0fdb15bcc7
Commit: 90799cdcee

12 changed files with 557 additions and 11 deletions
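
Before the hunks, a reading aid: the commit stubs a job-management surface with four operations (list, get, cancel, delete). The sketch below reconstructs that surface from the new jobs.py further down. The Jobs, JobInfo, and ListJobsResponse names come from the diff, but the model fields are hypothetical, since llama_stack/apis/jobs is not part of this view.

# Sketch only: the protocol methods match the stubs in jobs.py below;
# the fields on the two models are invented for illustration.
from typing import List, Protocol

from pydantic import BaseModel


class JobInfo(BaseModel):
    job_id: str  # hypothetical field
    status: str  # hypothetical field


class ListJobsResponse(BaseModel):
    data: List[JobInfo]  # hypothetical field


class Jobs(Protocol):
    async def list_jobs(self) -> ListJobsResponse: ...
    async def get_job(self, job_id: str) -> JobInfo: ...
    async def cancel_job(self, job_id: str) -> None: ...
    async def delete_job(self, job_id: str) -> None: ...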
@@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
 
 def providable_apis() -> List[Api]:
     routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
-    return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
+    return [api for api in Api if api not in routing_table_apis and api not in (Api.inspect, Api.jobs)]
 
 
 def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
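
providable_apis() enumerates the APIs that external providers may implement; APIs served by built-in distribution providers are excluded, and this hunk extends the exclusion from Api.inspect alone to both built-ins. A stand-in illustration of the effect (the enum here is a toy, not the real Api):

# Toy enum, not the real llama_stack Api: shows what the changed line does.
from enum import Enum


class Api(Enum):
    inference = "inference"
    inspect = "inspect"
    jobs = "jobs"


routing_table_apis = set()  # empty for the illustration
providable = [
    api
    for api in Api
    if api not in routing_table_apis and api not in (Api.inspect, Api.jobs)
]
assert providable == [Api.inference]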
llama_stack/distribution/jobs.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from llama_stack.apis.jobs import (
+    JobInfo,
+    Jobs,
+    ListJobsResponse,
+)
+from llama_stack.distribution.datatypes import StackRunConfig
+
+
+class DistributionJobsConfig(BaseModel):
+    run_config: StackRunConfig
+
+
+async def get_provider_impl(config, deps):
+    impl = DistributionJobsImpl(config, deps)
+    await impl.initialize()
+    return impl
+
+
+class DistributionJobsImpl(Jobs):
+    def __init__(self, config, deps):
+        self.config = config
+        self.deps = deps
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_jobs(self) -> ListJobsResponse:
+        raise NotImplementedError
+
+    async def delete_job(self, job_id: str) -> None:
+        raise NotImplementedError
+
+    async def cancel_job(self, job_id: str) -> None:
+        raise NotImplementedError
+
+    async def get_job(self, job_id: str) -> JobInfo:
+        raise NotImplementedError
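
The file follows the distribution's built-in provider shape: a pydantic config holding the run config, an async get_provider_impl factory, and an impl class with lifecycle hooks. A hedged usage sketch follows; in practice the resolver, not user code, drives this, and run_config is whatever StackRunConfig the stack was started with.

# Assumed usage; the resolver normally builds the config from the provider
# spec rather than by hand. Run with asyncio.run(demo(my_run_config)).
from llama_stack.distribution.jobs import DistributionJobsConfig, get_provider_impl


async def demo(run_config) -> None:
    config = DistributionJobsConfig(run_config=run_config)
    impl = await get_provider_impl(config, deps={})
    try:
        await impl.list_jobs()
    except NotImplementedError:
        print("jobs API is stubbed, as the commit message says")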
@@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
+from llama_stack.apis.jobs import Jobs
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -62,6 +63,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.agents: Agents,
         Api.inference: Inference,
         Api.inspect: Inspect,
+        Api.jobs: Jobs,
         Api.vector_io: VectorIO,
         Api.vector_dbs: VectorDBs,
         Api.models: Models,
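
api_protocol_map() returns the table mapping each Api value to the Python protocol its implementation must satisfy; registering Api.jobs: Jobs is what ties the new built-in provider to the Jobs interface. A self-contained stand-in of the pattern (toy types, not the real ones):

# Toy stand-ins, not the real llama_stack types: the shape of the check
# that a protocol map makes possible.
from enum import Enum
from typing import Protocol


class Api(Enum):
    jobs = "jobs"


class Jobs(Protocol):
    async def list_jobs(self): ...


api_protocol_map = {Api.jobs: Jobs}


class DistributionJobsImpl:
    async def list_jobs(self):
        raise NotImplementedError


proto = api_protocol_map[Api.jobs]
impl = DistributionJobsImpl()
public = [n for n in vars(proto) if not n.startswith("_")]
assert all(hasattr(impl, n) for n in public)  # impl satisfies the protocol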
@@ -226,26 +228,32 @@ def sort_providers_by_deps(
         {k: list(v.values()) for k, v in providers_with_specs.items()}
     )
 
-    # Append built-in "inspect" provider
+    # Append built-in providers
     apis = [x[1].spec.api for x in sorted_providers]
-    sorted_providers.append(
+    deps = [x.value for x in apis]
+    config = run_config.model_dump()
+    sorted_providers += [
         (
-            "inspect",
+            name,
             ProviderWithSpec(
                 provider_id="__builtin__",
                 provider_type="__builtin__",
-                config={"run_config": run_config.model_dump()},
+                config={"run_config": config},
                 spec=InlineProviderSpec(
-                    api=Api.inspect,
+                    api=api,
                     provider_type="__builtin__",
-                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
-                    module="llama_stack.distribution.inspect",
+                    config_class=f"llama_stack.distribution.{name}.Distribution{name.title()}Config",
+                    module=f"llama_stack.distribution.{name}",
                     api_dependencies=apis,
-                    deps__=[x.value for x in apis],
+                    deps__=deps,
                 ),
             ),
         )
-    )
+        for name, api in [
+            ("inspect", Api.inspect),
+            ("jobs", Api.jobs),
+        ]
+    ]
 
     logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
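
The refactor folds the one-off inspect registration into a comprehension over (name, api) pairs, deriving the module and config-class paths from the name, and hoists deps and config out of the loop so they are computed once. Note the naming contract this creates: name.title() must produce the class prefix, which holds for both entries:

# The f-string convention the comprehension relies on; both builtin names
# title-case into their config class names.
for name in ("inspect", "jobs"):
    print(f"llama_stack.distribution.{name}.Distribution{name.title()}Config")
# -> llama_stack.distribution.inspect.DistributionInspectConfig
# -> llama_stack.distribution.jobs.DistributionJobsConfig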
@@ -367,7 +367,9 @@ def main():
                 continue
             apis_to_serve.add(inf.routing_table_api.value)
 
-        apis_to_serve.add("inspect")
+        # also include builtin APIs
+        apis_to_serve += {"inspect", "jobs"}
+
     for api_str in apis_to_serve:
         api = Api(api_str)
 
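
One thing to verify against the rest of the commit (not all 12 changed files are shown in this view): the removed line called .add(), a set method, so apis_to_serve was a set here before the change, and a set does not implement +, so += raises TypeError unless another hunk turned the container into a list. The in-place union spelling for a set is |=. A quick behavior check:

# Behavior check for the container types apis_to_serve could be.
apis = {"models"}
try:
    apis += {"inspect", "jobs"}  # set: no + operator, raises TypeError
except TypeError:
    apis |= {"inspect", "jobs"}  # in-place union works for sets
assert apis == {"models", "inspect", "jobs"}

apis_list = ["models"]
apis_list += {"inspect", "jobs"}  # list +=: extends, order unspecified
assert sorted(apis_list) == ["inspect", "jobs", "models"]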
@@ -21,6 +21,7 @@ from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
+from llama_stack.apis.jobs import Jobs
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -62,6 +63,7 @@ class LlamaStack(
     Models,
     Shields,
     Inspect,
+    Jobs,
     ToolGroups,
     ToolRuntime,
     RAGToolRuntime,
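
LlamaStack aggregates every API interface into the single class the server exposes; listing Jobs among its bases is what commits the stack to the new surface. A toy sketch of the composition pattern (stand-in protocols, not the real ones):

# Stand-in protocols: one concrete class inherits them all, so a single
# object carries the full API surface.
from typing import Protocol


class Inspect(Protocol):
    async def version(self): ...


class Jobs(Protocol):
    async def list_jobs(self): ...


class Stack(Inspect, Jobs):
    async def version(self):
        return "0.0.0"

    async def list_jobs(self):
        raise NotImplementedError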