forked from phoenix-oss/llama-stack-mirror
Add version to REST API url (#478)
# What does this PR do? Adds a `/alpha/` prefix to all the REST API urls. Also makes them all use hyphens instead of underscores as is more standard practice. (This is based on feedback from our partners.) ## Test Plan The Stack itself does not need updating. However, client SDKs and documentation will need to be updated.
This commit is contained in:
parent
05e93bd2f7
commit
0dc7f5fa89
18 changed files with 32842 additions and 6032 deletions
|
@ -31,7 +31,12 @@ from .strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
schema_utils.json_schema_type = json_schema_type
|
schema_utils.json_schema_type = json_schema_type
|
||||||
|
|
||||||
from llama_stack.distribution.stack import LlamaStack
|
# this line needs to be here to ensure json_schema_type has been altered before
|
||||||
|
# the imports use the annotation
|
||||||
|
from llama_stack.distribution.stack import ( # noqa: E402
|
||||||
|
LLAMA_STACK_API_VERSION,
|
||||||
|
LlamaStack,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(output_dir: str):
|
def main(output_dir: str):
|
||||||
|
@ -50,7 +55,7 @@ def main(output_dir: str):
|
||||||
server=Server(url="http://any-hosted-llama-stack.com"),
|
server=Server(url="http://any-hosted-llama-stack.com"),
|
||||||
info=Info(
|
info=Info(
|
||||||
title="[DRAFT] Llama Stack Specification",
|
title="[DRAFT] Llama Stack Specification",
|
||||||
version="0.0.1",
|
version=LLAMA_STACK_API_VERSION,
|
||||||
description="""This is the specification of the llama stack that provides
|
description="""This is the specification of the llama stack that provides
|
||||||
a set of endpoints and their corresponding interfaces that are tailored to
|
a set of endpoints and their corresponding interfaces that are tailored to
|
||||||
best leverage Llama Models. The specification is still in draft and subject to change.
|
best leverage Llama Models. The specification is still in draft and subject to change.
|
||||||
|
|
|
@ -202,7 +202,9 @@ class ContentBuilder:
|
||||||
) -> MediaType:
|
) -> MediaType:
|
||||||
schema = self.schema_builder.classdef_to_ref(item_type)
|
schema = self.schema_builder.classdef_to_ref(item_type)
|
||||||
if self.schema_transformer:
|
if self.schema_transformer:
|
||||||
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = self.schema_transformer # type: ignore
|
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = (
|
||||||
|
self.schema_transformer
|
||||||
|
) # type: ignore
|
||||||
schema = schema_transformer(schema)
|
schema = schema_transformer(schema)
|
||||||
|
|
||||||
if not examples:
|
if not examples:
|
||||||
|
@ -630,6 +632,7 @@ class Generator:
|
||||||
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
||||||
|
|
||||||
route = op.get_route()
|
route = op.get_route()
|
||||||
|
print(f"route: {route}")
|
||||||
if route in paths:
|
if route in paths:
|
||||||
paths[route].update(pathItem)
|
paths[route].update(pathItem)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -12,6 +12,8 @@ import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from llama_stack.distribution.stack import LLAMA_STACK_API_VERSION
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from ..strong_typing.inspection import (
|
from ..strong_typing.inspection import (
|
||||||
|
@ -111,9 +113,12 @@ class EndpointOperation:
|
||||||
|
|
||||||
def get_route(self) -> str:
|
def get_route(self) -> str:
|
||||||
if self.route is not None:
|
if self.route is not None:
|
||||||
return self.route
|
assert (
|
||||||
|
"_" not in self.route
|
||||||
|
), f"route should not contain underscores: {self.route}"
|
||||||
|
return "/".join(["", LLAMA_STACK_API_VERSION, self.route.lstrip("/")])
|
||||||
|
|
||||||
route_parts = ["", self.name]
|
route_parts = ["", LLAMA_STACK_API_VERSION, self.name]
|
||||||
for param_name, _ in self.path_params:
|
for param_name, _ in self.path_params:
|
||||||
route_parts.append("{" + param_name + "}")
|
route_parts.append("{" + param_name + "}")
|
||||||
return "/".join(route_parts)
|
return "/".join(route_parts)
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -49,7 +49,7 @@ class BatchChatCompletionResponse(BaseModel):
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class BatchInference(Protocol):
|
class BatchInference(Protocol):
|
||||||
@webmethod(route="/batch_inference/completion")
|
@webmethod(route="/batch-inference/completion")
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -58,7 +58,7 @@ class BatchInference(Protocol):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> BatchCompletionResponse: ...
|
) -> BatchCompletionResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/batch_inference/chat_completion")
|
@webmethod(route="/batch-inference/chat-completion")
|
||||||
async def batch_chat_completion(
|
async def batch_chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -29,7 +29,7 @@ class DatasetIO(Protocol):
|
||||||
# keeping for aligning with inference/safety, but this is not used
|
# keeping for aligning with inference/safety, but this is not used
|
||||||
dataset_store: DatasetStore
|
dataset_store: DatasetStore
|
||||||
|
|
||||||
@webmethod(route="/datasetio/get_rows_paginated", method="GET")
|
@webmethod(route="/datasetio/get-rows-paginated", method="GET")
|
||||||
async def get_rows_paginated(
|
async def get_rows_paginated(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
|
|
@ -74,14 +74,14 @@ class EvaluateResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Eval(Protocol):
|
class Eval(Protocol):
|
||||||
@webmethod(route="/eval/run_eval", method="POST")
|
@webmethod(route="/eval/run-eval", method="POST")
|
||||||
async def run_eval(
|
async def run_eval(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
task_config: EvalTaskConfig,
|
task_config: EvalTaskConfig,
|
||||||
) -> Job: ...
|
) -> Job: ...
|
||||||
|
|
||||||
@webmethod(route="/eval/evaluate_rows", method="POST")
|
@webmethod(route="/eval/evaluate-rows", method="POST")
|
||||||
async def evaluate_rows(
|
async def evaluate_rows(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
|
|
|
@ -42,13 +42,13 @@ class EvalTaskInput(CommonEvalTaskFields, BaseModel):
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class EvalTasks(Protocol):
|
class EvalTasks(Protocol):
|
||||||
@webmethod(route="/eval_tasks/list", method="GET")
|
@webmethod(route="/eval-tasks/list", method="GET")
|
||||||
async def list_eval_tasks(self) -> List[EvalTask]: ...
|
async def list_eval_tasks(self) -> List[EvalTask]: ...
|
||||||
|
|
||||||
@webmethod(route="/eval_tasks/get", method="GET")
|
@webmethod(route="/eval-tasks/get", method="GET")
|
||||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
|
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
|
||||||
|
|
||||||
@webmethod(route="/eval_tasks/register", method="POST")
|
@webmethod(route="/eval-tasks/register", method="POST")
|
||||||
async def register_eval_task(
|
async def register_eval_task(
|
||||||
self,
|
self,
|
||||||
eval_task_id: str,
|
eval_task_id: str,
|
||||||
|
|
|
@ -234,7 +234,7 @@ class Inference(Protocol):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
|
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/chat_completion")
|
@webmethod(route="/inference/chat-completion")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -130,13 +130,13 @@ class MemoryBankInput(BaseModel):
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class MemoryBanks(Protocol):
|
class MemoryBanks(Protocol):
|
||||||
@webmethod(route="/memory_banks/list", method="GET")
|
@webmethod(route="/memory-banks/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get", method="GET")
|
@webmethod(route="/memory-banks/get", method="GET")
|
||||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
|
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/register", method="POST")
|
@webmethod(route="/memory-banks/register", method="POST")
|
||||||
async def register_memory_bank(
|
async def register_memory_bank(
|
||||||
self,
|
self,
|
||||||
memory_bank_id: str,
|
memory_bank_id: str,
|
||||||
|
@ -145,5 +145,5 @@ class MemoryBanks(Protocol):
|
||||||
provider_memory_bank_id: Optional[str] = None,
|
provider_memory_bank_id: Optional[str] = None,
|
||||||
) -> MemoryBank: ...
|
) -> MemoryBank: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/unregister", method="POST")
|
@webmethod(route="/memory-banks/unregister", method="POST")
|
||||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||||
|
|
|
@ -176,7 +176,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PostTraining(Protocol):
|
class PostTraining(Protocol):
|
||||||
@webmethod(route="/post_training/supervised_fine_tune")
|
@webmethod(route="/post-training/supervised-fine-tune")
|
||||||
def supervised_fine_tune(
|
def supervised_fine_tune(
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
|
@ -193,7 +193,7 @@ class PostTraining(Protocol):
|
||||||
logger_config: Dict[str, Any],
|
logger_config: Dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/preference_optimize")
|
@webmethod(route="/post-training/preference-optimize")
|
||||||
def preference_optimize(
|
def preference_optimize(
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
|
@ -208,22 +208,22 @@ class PostTraining(Protocol):
|
||||||
logger_config: Dict[str, Any],
|
logger_config: Dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/jobs")
|
@webmethod(route="/post-training/jobs")
|
||||||
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||||
|
|
||||||
# sends SSE stream of logs
|
# sends SSE stream of logs
|
||||||
@webmethod(route="/post_training/job/logs")
|
@webmethod(route="/post-training/job/logs")
|
||||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/status")
|
@webmethod(route="/post-training/job/status")
|
||||||
def get_training_job_status(
|
def get_training_job_status(
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> PostTrainingJobStatusResponse: ...
|
) -> PostTrainingJobStatusResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/cancel")
|
@webmethod(route="/post-training/job/cancel")
|
||||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/artifacts")
|
@webmethod(route="/post-training/job/artifacts")
|
||||||
def get_training_job_artifacts(
|
def get_training_job_artifacts(
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> PostTrainingJobArtifactsResponse: ...
|
) -> PostTrainingJobArtifactsResponse: ...
|
||||||
|
|
|
@ -46,7 +46,7 @@ class ShieldStore(Protocol):
|
||||||
class Safety(Protocol):
|
class Safety(Protocol):
|
||||||
shield_store: ShieldStore
|
shield_store: ShieldStore
|
||||||
|
|
||||||
@webmethod(route="/safety/run_shield")
|
@webmethod(route="/safety/run-shield")
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -44,7 +44,7 @@ class ScoringFunctionStore(Protocol):
|
||||||
class Scoring(Protocol):
|
class Scoring(Protocol):
|
||||||
scoring_function_store: ScoringFunctionStore
|
scoring_function_store: ScoringFunctionStore
|
||||||
|
|
||||||
@webmethod(route="/scoring/score_batch")
|
@webmethod(route="/scoring/score-batch")
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
|
|
@ -104,13 +104,13 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class ScoringFunctions(Protocol):
|
class ScoringFunctions(Protocol):
|
||||||
@webmethod(route="/scoring_functions/list", method="GET")
|
@webmethod(route="/scoring-functions/list", method="GET")
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring_functions/get", method="GET")
|
@webmethod(route="/scoring-functions/get", method="GET")
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
|
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring_functions/register", method="POST")
|
@webmethod(route="/scoring-functions/register", method="POST")
|
||||||
async def register_scoring_function(
|
async def register_scoring_function(
|
||||||
self,
|
self,
|
||||||
scoring_fn_id: str,
|
scoring_fn_id: str,
|
||||||
|
|
|
@ -44,7 +44,7 @@ class SyntheticDataGenerationResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataGeneration(Protocol):
|
class SyntheticDataGeneration(Protocol):
|
||||||
@webmethod(route="/synthetic_data_generation/generate")
|
@webmethod(route="/synthetic-data-generation/generate")
|
||||||
def synthetic_data_generate(
|
def synthetic_data_generate(
|
||||||
self,
|
self,
|
||||||
dialogs: List[Message],
|
dialogs: List[Message],
|
||||||
|
|
|
@ -125,8 +125,8 @@ Event = Annotated[
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/log_event")
|
@webmethod(route="/telemetry/log-event")
|
||||||
async def log_event(self, event: Event) -> None: ...
|
async def log_event(self, event: Event) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get_trace", method="GET")
|
@webmethod(route="/telemetry/get-trace", method="GET")
|
||||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||||
|
|
|
@ -40,6 +40,9 @@ from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
LLAMA_STACK_API_VERSION = "alpha"
|
||||||
|
|
||||||
|
|
||||||
class LlamaStack(
|
class LlamaStack(
|
||||||
MemoryBanks,
|
MemoryBanks,
|
||||||
Inference,
|
Inference,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue