mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 08:22:17 +00:00
feat(api): remove List* response types and nils for get/list
TODO: - make sure docstrings are refreshed as needed. - make sure this passes tests. - address a TODO in code (obsolete comment?) - make sure client side still works. - analyze if any providers need adjustments. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
bfc79217a8
commit
90ed785fbd
21 changed files with 222 additions and 935 deletions
|
|
@ -241,16 +241,6 @@ class Agent(BaseModel):
|
|||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
|
@ -559,10 +549,10 @@ class Agents(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
async def list_agents(self) -> list[Agent]:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
:returns: a list of Agents.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -579,10 +569,10 @@ class Agents(Protocol):
|
|||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
) -> list[Session]:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
:returns: A list of agent Sessions.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
CompletionResponse,
|
||||
|
|
@ -20,17 +18,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
batch: List[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
batch: List[ChatCompletionResponse]
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -43,7 +31,7 @@ class BatchInference(Protocol):
|
|||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse: ...
|
||||
) -> list[CompletionResponse]: ...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def batch_chat_completion(
|
||||
|
|
@ -57,4 +45,4 @@ class BatchInference(Protocol):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchChatCompletionResponse: ...
|
||||
) -> list[ChatCompletionResponse]: ...
|
||||
|
|
|
|||
|
|
@ -39,20 +39,16 @@ class BenchmarkInput(CommonBenchmarkFields, BaseModel):
|
|||
provider_benchmark_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListBenchmarksResponse(BaseModel):
|
||||
data: List[Benchmark]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET")
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
|
||||
async def list_benchmarks(self) -> list[Benchmark]: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Optional[Benchmark]: ...
|
||||
) -> Benchmark: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
async def register_benchmark(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
from typing import Any, Dict, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -42,10 +42,6 @@ class DatasetInput(CommonDatasetFields, BaseModel):
|
|||
provider_dataset_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
data: List[Dataset]
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST")
|
||||
async def register_dataset(
|
||||
|
|
@ -62,10 +58,10 @@ class Datasets(Protocol):
|
|||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Optional[Dataset]: ...
|
||||
) -> Dataset: ...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
async def list_datasets(self) -> list[Dataset]: ...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||
async def unregister_dataset(
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class Eval(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class FileUploadResponse(BaseModel):
|
||||
class FileUpload(BaseModel):
|
||||
"""
|
||||
Response after initiating a file upload session.
|
||||
|
||||
|
|
@ -30,23 +30,12 @@ class FileUploadResponse(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class BucketResponse(BaseModel):
|
||||
class Bucket(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBucketResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of file entries.
|
||||
|
||||
:param data: List of FileResponse entries
|
||||
"""
|
||||
|
||||
data: List[BucketResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FileResponse(BaseModel):
|
||||
class File(BaseModel):
|
||||
"""
|
||||
Response representing a file entry.
|
||||
|
||||
|
|
@ -66,17 +55,6 @@ class FileResponse(BaseModel):
|
|||
created_at: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListFileResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of file entries.
|
||||
|
||||
:param data: List of FileResponse entries
|
||||
"""
|
||||
|
||||
data: List[FileResponse]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
|
|
@ -87,7 +65,7 @@ class Files(Protocol):
|
|||
key: str,
|
||||
mime_type: str,
|
||||
size: int,
|
||||
) -> FileUploadResponse:
|
||||
) -> FileUpload:
|
||||
"""
|
||||
Create a new upload session for a file identified by a bucket and key.
|
||||
|
||||
|
|
@ -102,7 +80,7 @@ class Files(Protocol):
|
|||
async def upload_content_to_session(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileResponse]:
|
||||
) -> File:
|
||||
"""
|
||||
Upload file content to an existing upload session.
|
||||
On the server, request body will have the raw bytes that are uploaded.
|
||||
|
|
@ -115,7 +93,7 @@ class Files(Protocol):
|
|||
async def get_upload_session_info(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileUploadResponse]:
|
||||
) -> FileUpload:
|
||||
"""
|
||||
Returns information about an existsing upload session
|
||||
|
||||
|
|
@ -127,7 +105,7 @@ class Files(Protocol):
|
|||
async def list_all_buckets(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListBucketResponse:
|
||||
) -> list[Bucket]:
|
||||
"""
|
||||
List all buckets.
|
||||
"""
|
||||
|
|
@ -137,7 +115,7 @@ class Files(Protocol):
|
|||
async def list_files_in_bucket(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListFileResponse:
|
||||
) -> list[File]:
|
||||
"""
|
||||
List all files in a bucket.
|
||||
|
||||
|
|
@ -150,7 +128,7 @@ class Files(Protocol):
|
|||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
) -> File:
|
||||
"""
|
||||
Get a file info identified by a bucket and key.
|
||||
|
||||
|
|
@ -164,7 +142,7 @@ class Files(Protocol):
|
|||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
) -> File:
|
||||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
|
|
|
|||
|
|
@ -31,26 +31,18 @@ class ProviderInfo(BaseModel):
|
|||
provider_type: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
data: List[RouteInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
async def list_providers(self) -> list[ProviderInfo]: ...
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
async def list_routes(self) -> list[RouteInfo]: ...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
async def health(self) -> HealthInfo: ...
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -52,21 +52,17 @@ class ModelInput(CommonModelFields):
|
|||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ListModelsResponse(BaseModel):
|
||||
data: List[Model]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET")
|
||||
async def list_models(self) -> ListModelsResponse: ...
|
||||
async def list_models(self) -> list[Model]: ...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Optional[Model]: ...
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
async def register_model(
|
||||
|
|
|
|||
|
|
@ -157,10 +157,6 @@ class PostTrainingJobStatusResponse(BaseModel):
|
|||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ListPostTrainingJobsResponse(BaseModel):
|
||||
data: List[PostTrainingJob]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a finetuning job."""
|
||||
|
|
@ -199,13 +195,13 @@ class PostTraining(Protocol):
|
|||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
async def get_training_jobs(self) -> list[PostTrainingJob]: ...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -19,10 +19,6 @@ class ProviderInfo(BaseModel):
|
|||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
|
|
@ -30,7 +26,7 @@ class Providers(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
async def list_providers(self) -> list[ProviderInfo]: ...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
|
|
|
|||
|
|
@ -125,17 +125,13 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
|||
provider_scoring_fn_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListScoringFunctionsResponse(BaseModel):
|
||||
data: List[ScoringFn]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -38,18 +38,14 @@ class ShieldInput(CommonShieldFields):
|
|||
provider_shield_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListShieldsResponse(BaseModel):
|
||||
data: List[Shield]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET")
|
||||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
async def list_shields(self) -> list[Shield]: ...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
async def register_shield(
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ class MetricInResponse(BaseModel):
|
|||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# TODO: check what this comment is about
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be inlcuded with the response
|
||||
|
|
|
|||
|
|
@ -80,14 +80,6 @@ class ToolStore(Protocol):
|
|||
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
class ListToolGroupsResponse(BaseModel):
|
||||
data: List[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
data: List[Tool]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
|
|
@ -109,12 +101,12 @@ class ToolGroups(Protocol):
|
|||
) -> ToolGroup: ...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET")
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
async def list_tool_groups(self) -> list[ToolGroup]:
|
||||
"""List tool groups with optional provider"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET")
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> list[Tool]:
|
||||
"""List tools with optional tool group"""
|
||||
...
|
||||
|
||||
|
|
@ -148,7 +140,7 @@ class ToolRuntime(Protocol):
|
|||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]: ...
|
||||
) -> list[ToolDef]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
from typing import Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -36,21 +36,17 @@ class VectorDBInput(BaseModel):
|
|||
provider_vector_db_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListVectorDBsResponse(BaseModel):
|
||||
data: List[VectorDB]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||
async def list_vector_dbs(self) -> list[VectorDB]: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> Optional[VectorDB]: ...
|
||||
) -> VectorDB: ...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
|
|
@ -60,7 +56,7 @@ class VectorDBs(Protocol):
|
|||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> VectorDB: ...
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class QueryChunksResponse(BaseModel):
|
|||
|
||||
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
|
||||
def get_vector_db(self, vector_db_id: str) -> VectorDB: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue