feat(api): remove List* response types and nils for get/list

TODO:
- make sure docstrings are refreshed as needed.
- make sure this passes tests.
- address a TODO in code (obsolete comment?)
- make sure client side still works.
- analyze if any providers need adjustments.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-14 10:25:59 -04:00
parent bfc79217a8
commit 90ed785fbd
21 changed files with 222 additions and 935 deletions

View file

@ -9,28 +9,25 @@ from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.benchmarks import Benchmark, Benchmarks
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_dbs import VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -208,8 +205,8 @@ class CommonRoutingTableImpl(RoutingTable):
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def list_models(self) -> list[Model]:
return await self.get_all_with_type("model")
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
@ -256,8 +253,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def list_shields(self) -> list[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -292,8 +289,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def list_vector_dbs(self) -> list[VectorDB]:
return await self.get_all_with_type("vector_db")
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
return await self.get_object_by_identifier("vector_db", vector_db_id)
@ -344,8 +341,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def list_datasets(self) -> list[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -389,8 +386,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def list_scoring_functions(self) -> list[ScoringFunctions]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -426,8 +423,8 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def list_benchmarks(self) -> list[Benchmark]:
return await self.get_all_with_type("benchmark")
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
return await self.get_object_by_identifier("benchmark", benchmark_id)
@ -464,14 +461,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> list[Tool]:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
return tools
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def list_tool_groups(self) -> list[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)