feat(api): remove List* response types and nils for get/list

TODO:
- make sure docstrings are refreshed as needed.
- make sure this passes tests.
- address a TODO in code (obsolete comment?)
- make sure client side still works.
- analyze if any providers need adjustments.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-14 10:25:59 -04:00
parent bfc79217a8
commit 90ed785fbd
21 changed files with 222 additions and 935 deletions

View file

@ -241,16 +241,6 @@ class Agent(BaseModel):
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
@ -559,10 +549,10 @@ class Agents(Protocol):
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
async def list_agents(self) -> list[Agent]:
"""List all agents.
:returns: A ListAgentsResponse.
:returns: a list of Agents.
"""
...
@ -579,10 +569,10 @@ class Agents(Protocol):
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
) -> list[Session]:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
:returns: A list of agent Sessions.
"""
...

View file

@ -6,8 +6,6 @@
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
@ -20,17 +18,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod
@runtime_checkable
@ -43,7 +31,7 @@ class BatchInference(Protocol):
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
) -> list[CompletionResponse]: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
@ -57,4 +45,4 @@ class BatchInference(Protocol):
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
) -> list[ChatCompletionResponse]: ...

View file

@ -39,20 +39,16 @@ class BenchmarkInput(CommonBenchmarkFields, BaseModel):
provider_benchmark_id: Optional[str] = None
class ListBenchmarksResponse(BaseModel):
data: List[Benchmark]
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
async def list_benchmarks(self) -> list[Benchmark]: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, Literal, Optional, Protocol
from pydantic import BaseModel, Field
@ -42,10 +42,6 @@ class DatasetInput(CommonDatasetFields, BaseModel):
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):
data: List[Dataset]
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
async def register_dataset(
@ -62,10 +58,10 @@ class Datasets(Protocol):
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
) -> Dataset: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
async def list_datasets(self) -> list[Dataset]: ...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(

View file

@ -117,7 +117,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from pydantic import BaseModel
@ -13,7 +13,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class FileUploadResponse(BaseModel):
class FileUpload(BaseModel):
"""
Response after initiating a file upload session.
@ -30,23 +30,12 @@ class FileUploadResponse(BaseModel):
@json_schema_type
class BucketResponse(BaseModel):
class Bucket(BaseModel):
name: str
@json_schema_type
class ListBucketResponse(BaseModel):
"""
Response representing a list of file entries.
:param data: List of FileResponse entries
"""
data: List[BucketResponse]
@json_schema_type
class FileResponse(BaseModel):
class File(BaseModel):
"""
Response representing a file entry.
@ -66,17 +55,6 @@ class FileResponse(BaseModel):
created_at: int
@json_schema_type
class ListFileResponse(BaseModel):
"""
Response representing a list of file entries.
:param data: List of FileResponse entries
"""
data: List[FileResponse]
@runtime_checkable
@trace_protocol
class Files(Protocol):
@ -87,7 +65,7 @@ class Files(Protocol):
key: str,
mime_type: str,
size: int,
) -> FileUploadResponse:
) -> FileUpload:
"""
Create a new upload session for a file identified by a bucket and key.
@ -102,7 +80,7 @@ class Files(Protocol):
async def upload_content_to_session(
self,
upload_id: str,
) -> Optional[FileResponse]:
) -> File:
"""
Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded.
@ -115,7 +93,7 @@ class Files(Protocol):
async def get_upload_session_info(
self,
upload_id: str,
) -> Optional[FileUploadResponse]:
) -> FileUpload:
"""
Returns information about an existsing upload session
@ -127,7 +105,7 @@ class Files(Protocol):
async def list_all_buckets(
self,
bucket: str,
) -> ListBucketResponse:
) -> list[Bucket]:
"""
List all buckets.
"""
@ -137,7 +115,7 @@ class Files(Protocol):
async def list_files_in_bucket(
self,
bucket: str,
) -> ListFileResponse:
) -> list[File]:
"""
List all files in a bucket.
@ -150,7 +128,7 @@ class Files(Protocol):
self,
bucket: str,
key: str,
) -> FileResponse:
) -> File:
"""
Get a file info identified by a bucket and key.
@ -164,7 +142,7 @@ class Files(Protocol):
self,
bucket: str,
key: str,
) -> FileResponse:
) -> File:
"""
Delete a file identified by a bucket and key.

View file

@ -31,26 +31,18 @@ class ProviderInfo(BaseModel):
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type
class VersionInfo(BaseModel):
version: str
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
async def list_providers(self) -> list[ProviderInfo]: ...
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...
async def list_routes(self) -> list[RouteInfo]: ...
@webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -52,21 +52,17 @@ class ModelInput(CommonModelFields):
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: List[Model]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
async def list_models(self) -> list[Model]: ...
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
) -> Model: ...
@webmethod(route="/models", method="POST")
async def register_model(

View file

@ -157,10 +157,6 @@ class PostTrainingJobStatusResponse(BaseModel):
checkpoints: List[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
@ -199,13 +195,13 @@ class PostTraining(Protocol):
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
async def get_training_jobs(self) -> list[PostTrainingJob]: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from typing import Any, Dict, Protocol, runtime_checkable
from pydantic import BaseModel
@ -19,10 +19,6 @@ class ProviderInfo(BaseModel):
config: Dict[str, Any]
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""
@ -30,7 +26,7 @@ class Providers(Protocol):
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
async def list_providers(self) -> list[ProviderInfo]: ...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...

View file

@ -125,17 +125,13 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
async def list_scoring_functions(self) -> list[ScoringFn]: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
@ -38,18 +38,14 @@ class ShieldInput(CommonShieldFields):
provider_shield_id: Optional[str] = None
class ListShieldsResponse(BaseModel):
data: List[Shield]
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
async def list_shields(self) -> list[Shield]: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(

View file

@ -103,6 +103,7 @@ class MetricInResponse(BaseModel):
unit: Optional[str] = None
# TODO: check what this comment is about
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be inlcuded with the response

View file

@ -80,14 +80,6 @@ class ToolStore(Protocol):
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup]
class ListToolsResponse(BaseModel):
data: List[Tool]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@ -109,12 +101,12 @@ class ToolGroups(Protocol):
) -> ToolGroup: ...
@webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse:
async def list_tool_groups(self) -> list[ToolGroup]:
"""List tool groups with optional provider"""
...
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> list[Tool]:
"""List tools with optional tool group"""
...
@ -148,7 +140,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
) -> list[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
@ -36,21 +36,17 @@ class VectorDBInput(BaseModel):
provider_vector_db_id: Optional[str] = None
class ListVectorDBsResponse(BaseModel):
data: List[VectorDB]
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
async def list_vector_dbs(self) -> list[VectorDB]: ...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
) -> Optional[VectorDB]: ...
) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(
@ -60,7 +56,7 @@ class VectorDBs(Protocol):
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> VectorDB: ...
) -> None: ...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...

View file

@ -30,7 +30,7 @@ class QueryChunksResponse(BaseModel):
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
def get_vector_db(self, vector_db_id: str) -> VectorDB: ...
@runtime_checkable