more idiomatic REST API

This commit is contained in:
Dinesh Yeduguru 2025-01-14 14:52:32 -08:00
parent d0a25dd453
commit b438dad8d2
29 changed files with 2144 additions and 1917 deletions

View file

@ -7,6 +7,7 @@
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
AsyncIterator,
Dict,
@ -20,7 +21,6 @@ from typing import (
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import (
@ -296,13 +296,13 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
@webmethod(route="/agents/create")
@webmethod(route="/agents", method="POST")
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
@webmethod(route="/agents/turn/create")
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn(
self,
agent_id: str,
@ -318,36 +318,52 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
async def get_agents_turn(
self, agent_id: str, session_id: str, turn_id: str
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agents/step/get")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
)
async def get_agents_step(
self, agent_id: str, session_id: str, turn_id: str, step_id: str
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse: ...
@webmethod(route="/agents/session/create")
@webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
@webmethod(route="/agents/session/get")
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agents/session/delete")
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
@webmethod(route="/agents/delete")
async def delete_agents(
@webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent(
self,
agent_id: str,
) -> None: ...

View file

@ -54,7 +54,7 @@ class BatchChatCompletionResponse(BaseModel):
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch-inference/completion")
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
self,
model: str,
@ -63,7 +63,7 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch-inference/chat-completion")
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
self,
model: str,

View file

@ -29,7 +29,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/get-rows-paginated", method="GET")
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
self,
dataset_id: str,
@ -38,7 +38,7 @@ class DatasetIO(Protocol):
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/append-rows", method="POST")
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -7,11 +7,9 @@
from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@ -44,8 +42,12 @@ class DatasetInput(CommonDatasetFields, BaseModel):
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):
data: List[Dataset]
class Datasets(Protocol):
@webmethod(route="/datasets/register", method="POST")
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
dataset_id: str,
@ -56,16 +58,16 @@ class Datasets(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/get", method="GET")
@webmethod(route="/datasets/{dataset_id}", method="GET")
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
@webmethod(route="/datasets/unregister", method="POST")
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -7,9 +7,7 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
@ -76,7 +74,7 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
@webmethod(route="/eval/run-eval", method="POST")
@webmethod(route="/eval/run", method="POST")
async def run_eval(
self,
task_id: str,
@ -92,11 +90,11 @@ class Eval(Protocol):
task_config: EvalTaskConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/jobs/{job_id}", method="GET")
async def job_status(self, job_id: str, task_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/jobs/cancel", method="POST")
async def job_cancel(self, job_id: str, task_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
@webmethod(route="/eval/jobs/{job_id}/result", method="GET")
async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...

View file

@ -6,7 +6,6 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
@ -40,15 +39,22 @@ class EvalTaskInput(CommonEvalTaskFields, BaseModel):
provider_eval_task_id: Optional[str] = None
class ListEvalTasksResponse(BaseModel):
data: List[EvalTask]
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval-tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTask]: ...
@webmethod(route="/eval-tasks", method="GET")
async def list_eval_tasks(self) -> ListEvalTasksResponse: ...
@webmethod(route="/eval-tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
async def get_eval_task(
self,
eval_task_id: str,
) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/register", method="POST")
@webmethod(route="/eval-tasks", method="POST")
async def register_eval_task(
self,
eval_task_id: str,

View file

@ -291,7 +291,7 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
@webmethod(route="/inference/completion", method="POST")
async def completion(
self,
model_id: str,
@ -302,7 +302,7 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat-completion")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
model_id: str,
@ -319,7 +319,7 @@ class Inference(Protocol):
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ...
@webmethod(route="/inference/embeddings")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
model_id: str,

View file

@ -34,10 +34,14 @@ class VersionInfo(BaseModel):
version: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/routes/list", method="GET")
async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...

View file

@ -50,7 +50,7 @@ class Memory(Protocol):
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory/insert")
@webmethod(route="/memory/insert", method="POST")
async def insert_documents(
self,
bank_id: str,
@ -58,7 +58,7 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/query")
@webmethod(route="/memory/query", method="POST")
async def query_documents(
self,
bank_id: str,

View file

@ -16,7 +16,6 @@ from typing import (
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
@ -133,16 +132,23 @@ class MemoryBankInput(BaseModel):
provider_memory_bank_id: Optional[str] = None
class ListMemoryBanksResponse(BaseModel):
data: List[MemoryBank]
@runtime_checkable
@trace_protocol
class MemoryBanks(Protocol):
@webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory-banks", method="GET")
async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
@webmethod(route="/memory-banks/get", method="GET")
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
async def get_memory_bank(
self,
memory_bank_id: str,
) -> Optional[MemoryBank]: ...
@webmethod(route="/memory-banks/register", method="POST")
@webmethod(route="/memory-banks", method="POST")
async def register_memory_bank(
self,
memory_bank_id: str,
@ -151,5 +157,5 @@ class MemoryBanks(Protocol):
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory-banks/unregister", method="POST")
@webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@ -52,16 +52,23 @@ class ModelInput(CommonModelFields):
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: List[Model]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[Model]: ...
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, identifier: str) -> Optional[Model]: ...
@webmethod(route="/models/{model_id}", method="GET")
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
@webmethod(route="/models/register", method="POST")
@webmethod(route="/models", method="POST")
async def register_model(
self,
model_id: str,
@ -71,5 +78,8 @@ class Models(Protocol):
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/unregister", method="POST")
async def unregister_model(self, model_id: str) -> None: ...
@webmethod(route="/models/{model_id}", method="DELETE")
async def unregister_model(
self,
model_id: str,
) -> None: ...

View file

@ -6,16 +6,13 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
@ -159,6 +156,10 @@ class PostTrainingJobStatusResponse(BaseModel):
checkpoints: List[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
@ -197,7 +198,7 @@ class PostTraining(Protocol):
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(

View file

@ -12,7 +12,6 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -49,7 +48,7 @@ class ShieldStore(Protocol):
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run-shield")
@webmethod(route="/safety/run-shield", method="POST")
async def run_shield(
self,
shield_id: str,

View file

@ -11,7 +11,6 @@ from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@ -43,7 +42,7 @@ class ScoringFunctionStore(Protocol):
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch")
@webmethod(route="/scoring/score-batch", method="POST")
async def score_batch(
self,
dataset_id: str,
@ -51,7 +50,7 @@ class Scoring(Protocol):
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")
@webmethod(route="/scoring/score", method="POST")
async def score(
self,
input_rows: List[Dict[str, Any]],

View file

@ -21,7 +21,6 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@ -129,15 +128,21 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/get", method="GET")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions/register", method="POST")
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,

View file

@ -38,16 +38,20 @@ class ShieldInput(CommonShieldFields):
provider_shield_id: Optional[str] = None
class ListShieldsResponse(BaseModel):
data: List[Shield]
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[Shield]: ...
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/get", method="GET")
@webmethod(route="/shields/{identifier}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields/register", method="POST")
@webmethod(route="/shields", method="POST")
async def register_shield(
self,
shield_id: str,

View file

@ -185,8 +185,8 @@ class Telemetry(Protocol):
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
@webmethod(route="/telemetry/get-span-tree", method="POST")
async def get_span_tree(
@webmethod(route="/telemetry/query-span-tree", method="POST")
async def query_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,

View file

@ -74,13 +74,21 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup]
class ListToolsResponse(BaseModel):
data: List[Tool]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
@webmethod(route="/toolgroups", method="POST")
async def register_tool_group(
self,
toolgroup_id: str,
@ -91,27 +99,33 @@ class ToolGroups(Protocol):
"""Register a tool group"""
...
@webmethod(route="/toolgroups/get", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
async def list_tool_groups(self) -> List[ToolGroup]:
@webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider"""
...
@webmethod(route="/tools/list", method="GET")
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
"""List tools with optional tool group"""
...
@webmethod(route="/tools/get", method="GET")
async def get_tool(self, tool_name: str) -> Tool: ...
@webmethod(route="/tools/{tool_name}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
@webmethod(route="/toolgroups/unregister", method="POST")
async def unregister_tool_group(self, tool_group_id: str) -> None:
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,
) -> None:
"""Unregister a tool group"""
...