Add version to REST API url (#478)

# What does this PR do? 

Adds a `/alpha/` prefix to all the REST API urls.

Also makes them all use hyphens instead of underscores as is more
standard practice.

(This is based on feedback from our partners.)

## Test Plan 

The Stack itself does not need updating. However, client SDKs and
documentation will need to be updated.
This commit is contained in:
Ashwin Bharambe 2024-11-18 22:44:14 -08:00 committed by GitHub
parent 05e93bd2f7
commit 0dc7f5fa89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 32842 additions and 6032 deletions

View file

@@ -31,7 +31,12 @@ from .strong_typing.schema import json_schema_type
schema_utils.json_schema_type = json_schema_type schema_utils.json_schema_type = json_schema_type
from llama_stack.distribution.stack import LlamaStack # this line needs to be here to ensure json_schema_type has been altered before
# the imports use the annotation
from llama_stack.distribution.stack import ( # noqa: E402
LLAMA_STACK_API_VERSION,
LlamaStack,
)
def main(output_dir: str): def main(output_dir: str):
@@ -50,7 +55,7 @@ def main(output_dir: str):
server=Server(url="http://any-hosted-llama-stack.com"), server=Server(url="http://any-hosted-llama-stack.com"),
info=Info( info=Info(
title="[DRAFT] Llama Stack Specification", title="[DRAFT] Llama Stack Specification",
version="0.0.1", version=LLAMA_STACK_API_VERSION,
description="""This is the specification of the llama stack that provides description="""This is the specification of the llama stack that provides
a set of endpoints and their corresponding interfaces that are tailored to a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models. The specification is still in draft and subject to change. best leverage Llama Models. The specification is still in draft and subject to change.

View file

@@ -202,7 +202,9 @@ class ContentBuilder:
) -> MediaType: ) -> MediaType:
schema = self.schema_builder.classdef_to_ref(item_type) schema = self.schema_builder.classdef_to_ref(item_type)
if self.schema_transformer: if self.schema_transformer:
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = self.schema_transformer # type: ignore schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = (
self.schema_transformer
) # type: ignore
schema = schema_transformer(schema) schema = schema_transformer(schema)
if not examples: if not examples:
@@ -630,6 +632,7 @@ class Generator:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}") raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route() route = op.get_route()
print(f"route: {route}")
if route in paths: if route in paths:
paths[route].update(pathItem) paths[route].update(pathItem)
else: else:

View file

@@ -12,6 +12,8 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from llama_stack.distribution.stack import LLAMA_STACK_API_VERSION
from termcolor import colored from termcolor import colored
from ..strong_typing.inspection import ( from ..strong_typing.inspection import (
@@ -111,9 +113,12 @@ class EndpointOperation:
def get_route(self) -> str: def get_route(self) -> str:
if self.route is not None: if self.route is not None:
return self.route assert (
"_" not in self.route
), f"route should not contain underscores: {self.route}"
return "/".join(["", LLAMA_STACK_API_VERSION, self.route.lstrip("/")])
route_parts = ["", self.name] route_parts = ["", LLAMA_STACK_API_VERSION, self.name]
for param_name, _ in self.path_params: for param_name, _ in self.path_params:
route_parts.append("{" + param_name + "}") route_parts.append("{" + param_name + "}")
return "/".join(route_parts) return "/".join(route_parts)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@@ -49,7 +49,7 @@ class BatchChatCompletionResponse(BaseModel):
@runtime_checkable @runtime_checkable
class BatchInference(Protocol): class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion") @webmethod(route="/batch-inference/completion")
async def batch_completion( async def batch_completion(
self, self,
model: str, model: str,
@@ -58,7 +58,7 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion") @webmethod(route="/batch-inference/chat-completion")
async def batch_chat_completion( async def batch_chat_completion(
self, self,
model: str, model: str,

View file

@@ -29,7 +29,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used # keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore dataset_store: DatasetStore
@webmethod(route="/datasetio/get_rows_paginated", method="GET") @webmethod(route="/datasetio/get-rows-paginated", method="GET")
async def get_rows_paginated( async def get_rows_paginated(
self, self,
dataset_id: str, dataset_id: str,

View file

@@ -74,14 +74,14 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol): class Eval(Protocol):
@webmethod(route="/eval/run_eval", method="POST") @webmethod(route="/eval/run-eval", method="POST")
async def run_eval( async def run_eval(
self, self,
task_id: str, task_id: str,
task_config: EvalTaskConfig, task_config: EvalTaskConfig,
) -> Job: ... ) -> Job: ...
@webmethod(route="/eval/evaluate_rows", method="POST") @webmethod(route="/eval/evaluate-rows", method="POST")
async def evaluate_rows( async def evaluate_rows(
self, self,
task_id: str, task_id: str,

View file

@@ -42,13 +42,13 @@ class EvalTaskInput(CommonEvalTaskFields, BaseModel):
@runtime_checkable @runtime_checkable
class EvalTasks(Protocol): class EvalTasks(Protocol):
@webmethod(route="/eval_tasks/list", method="GET") @webmethod(route="/eval-tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTask]: ... async def list_eval_tasks(self) -> List[EvalTask]: ...
@webmethod(route="/eval_tasks/get", method="GET") @webmethod(route="/eval-tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ... async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
@webmethod(route="/eval_tasks/register", method="POST") @webmethod(route="/eval-tasks/register", method="POST")
async def register_eval_task( async def register_eval_task(
self, self,
eval_task_id: str, eval_task_id: str,

View file

@@ -234,7 +234,7 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ... ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat_completion") @webmethod(route="/inference/chat-completion")
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,

View file

@@ -130,13 +130,13 @@ class MemoryBankInput(BaseModel):
@runtime_checkable @runtime_checkable
class MemoryBanks(Protocol): class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET") @webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ... async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get", method="GET") @webmethod(route="/memory-banks/get", method="GET")
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ... async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/register", method="POST") @webmethod(route="/memory-banks/register", method="POST")
async def register_memory_bank( async def register_memory_bank(
self, self,
memory_bank_id: str, memory_bank_id: str,
@@ -145,5 +145,5 @@ class MemoryBanks(Protocol):
provider_memory_bank_id: Optional[str] = None, provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ... ) -> MemoryBank: ...
@webmethod(route="/memory_banks/unregister", method="POST") @webmethod(route="/memory-banks/unregister", method="POST")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@@ -176,7 +176,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol): class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune") @webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune( def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
@@ -193,7 +193,7 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize") @webmethod(route="/post-training/preference-optimize")
def preference_optimize( def preference_optimize(
self, self,
job_uuid: str, job_uuid: str,
@@ -208,22 +208,22 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs") @webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ... def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs # sends SSE stream of logs
@webmethod(route="/post_training/job/logs") @webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status") @webmethod(route="/post-training/job/status")
def get_training_job_status( def get_training_job_status(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobStatusResponse: ... ) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel") @webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ... def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts") @webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts( def get_training_job_artifacts(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ... ) -> PostTrainingJobArtifactsResponse: ...

View file

@@ -46,7 +46,7 @@ class ShieldStore(Protocol):
class Safety(Protocol): class Safety(Protocol):
shield_store: ShieldStore shield_store: ShieldStore
@webmethod(route="/safety/run_shield") @webmethod(route="/safety/run-shield")
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,

View file

@@ -44,7 +44,7 @@ class ScoringFunctionStore(Protocol):
class Scoring(Protocol): class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score_batch") @webmethod(route="/scoring/score-batch")
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,

View file

@@ -104,13 +104,13 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
@runtime_checkable @runtime_checkable
class ScoringFunctions(Protocol): class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET") @webmethod(route="/scoring-functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFn]: ... async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring_functions/get", method="GET") @webmethod(route="/scoring-functions/get", method="GET")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ... async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring_functions/register", method="POST") @webmethod(route="/scoring-functions/register", method="POST")
async def register_scoring_function( async def register_scoring_function(
self, self,
scoring_fn_id: str, scoring_fn_id: str,

View file

@@ -44,7 +44,7 @@ class SyntheticDataGenerationResponse(BaseModel):
class SyntheticDataGeneration(Protocol): class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate") @webmethod(route="/synthetic-data-generation/generate")
def synthetic_data_generate( def synthetic_data_generate(
self, self,
dialogs: List[Message], dialogs: List[Message],

View file

@@ -125,8 +125,8 @@ Event = Annotated[
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event") @webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ... async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get_trace", method="GET") @webmethod(route="/telemetry/get-trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace: ...

View file

@@ -40,6 +40,9 @@ from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
LLAMA_STACK_API_VERSION = "alpha"
class LlamaStack( class LlamaStack(
MemoryBanks, MemoryBanks,
Inference, Inference,