feat(api): remove List* response types and nils for get/list

TODO:
- make sure docstrings are refreshed as needed.
- make sure this passes tests.
- address a TODO in code (obsolete comment?)
- make sure client side still works.
- analyze if any providers need adjustments.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-14 10:25:59 -04:00
parent bfc79217a8
commit 90ed785fbd
21 changed files with 222 additions and 935 deletions

View file

@ -11,8 +11,6 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListProvidersResponse,
ListRoutesResponse,
ProviderInfo,
RouteInfo,
VersionInfo,
@ -39,43 +37,27 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
async def list_providers(self) -> list[ProviderInfo]:
return [
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for api, providers in self.config.run_config.providers.items()
for p in providers
]
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config
ret = []
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]
async def list_routes(self) -> list[RouteInfo]:
return [
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in self.config.run_config.providers.get(api.value, [])],
)
return ListRoutesResponse(data=ret)
for api, endpoints in get_all_api_endpoints().items()
for e in endpoints
]
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")

View file

@ -7,7 +7,7 @@
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.apis.providers import ProviderInfo, Providers
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
@ -31,28 +31,23 @@ class ProviderImpl(Providers):
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
async def list_providers(self) -> list[ProviderInfo]:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
ret = []
for api, providers in safe_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
)
for p in providers
]
return [
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
)
return ListProvidersResponse(data=ret)
for api, providers in safe_config.providers.items()
for p in providers
]
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
for p in all_providers:
if p.provider_id == provider_id:
return p

View file

@ -9,28 +9,25 @@ from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.benchmarks import Benchmark, Benchmarks
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_dbs import VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -208,8 +205,8 @@ class CommonRoutingTableImpl(RoutingTable):
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def list_models(self) -> list[Model]:
return await self.get_all_with_type("model")
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
@ -256,8 +253,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def list_shields(self) -> list[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -292,8 +289,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def list_vector_dbs(self) -> list[VectorDB]:
return await self.get_all_with_type("vector_db")
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
return await self.get_object_by_identifier("vector_db", vector_db_id)
@ -344,8 +341,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def list_datasets(self) -> list[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -389,8 +386,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def list_scoring_functions(self) -> list[ScoringFunctions]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -426,8 +423,8 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def list_benchmarks(self) -> list[Benchmark]:
return await self.get_all_with_type("benchmark")
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
return await self.get_object_by_identifier("benchmark", benchmark_id)
@ -464,14 +461,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> list[Tool]:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
return tools
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def list_tool_groups(self) -> list[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)