mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 02:52:16 +00:00
feat(api): remove List* response types and nils for get/list
TODO: - make sure docstrings are refreshed as needed. - make sure this passes tests. - address a TODO in code (obsolete comment?) - make sure client side still works. - analyze if any providers need adjustments. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
bfc79217a8
commit
90ed785fbd
21 changed files with 222 additions and 935 deletions
|
|
@ -11,8 +11,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.inspect import (
|
||||
HealthInfo,
|
||||
Inspect,
|
||||
ListProvidersResponse,
|
||||
ListRoutesResponse,
|
||||
ProviderInfo,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
|
|
@ -39,43 +37,27 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = []
|
||||
for api, providers in run_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
async def list_providers(self) -> list[ProviderInfo]:
|
||||
return [
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for api, providers in self.config.run_config.providers.items()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = []
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
async def list_routes(self) -> list[RouteInfo]:
|
||||
return [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
provider_types=[p.provider_type for p in self.config.run_config.providers.get(api.value, [])],
|
||||
)
|
||||
|
||||
return ListRoutesResponse(data=ret)
|
||||
for api, endpoints in get_all_api_endpoints().items()
|
||||
for e in endpoints
|
||||
]
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.apis.providers import ProviderInfo, Providers
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
|
|
@ -31,28 +31,23 @@ class ProviderImpl(Providers):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
async def list_providers(self) -> list[ProviderInfo]:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
return [
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
for api, providers in safe_config.providers.items()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
for p in all_providers:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
|
|
|
|||
|
|
@ -9,28 +9,25 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets
|
||||
from llama_stack.apis.models import Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
ListScoringFunctionsResponse,
|
||||
ScoringFn,
|
||||
ScoringFnParams,
|
||||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolGroupsResponse,
|
||||
ListToolsResponse,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroups,
|
||||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
|
@ -208,8 +205,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
async def list_models(self) -> list[Model]:
|
||||
return await self.get_all_with_type("model")
|
||||
|
||||
async def get_model(self, model_id: str) -> Optional[Model]:
|
||||
return await self.get_object_by_identifier("model", model_id)
|
||||
|
|
@ -256,8 +253,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||
async def list_shields(self) -> list[Shield]:
|
||||
return await self.get_all_with_type(ResourceType.shield.value)
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
|
|
@ -292,8 +289,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return await self.get_all_with_type("vector_db")
|
||||
|
||||
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
|
||||
return await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
|
|
@ -344,8 +341,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
async def list_datasets(self) -> list[Dataset]:
|
||||
return await self.get_all_with_type(ResourceType.dataset.value)
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
|
@ -389,8 +386,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
||||
async def list_scoring_functions(self) -> list[ScoringFunctions]:
|
||||
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
|
@ -426,8 +423,8 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
|
||||
|
||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
||||
async def list_benchmarks(self) -> list[Benchmark]:
|
||||
return await self.get_all_with_type("benchmark")
|
||||
|
||||
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
|
||||
return await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||
|
|
@ -464,14 +461,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> list[Tool]:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if toolgroup_id:
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
||||
return ListToolsResponse(data=tools)
|
||||
return tools
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
async def list_tool_groups(self) -> list[ToolGroup]:
|
||||
return await self.get_all_with_type("tool_group")
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue