more idiomatic REST API

This commit is contained in:
Dinesh Yeduguru 2025-01-14 14:52:32 -08:00
parent d0a25dd453
commit b438dad8d2
29 changed files with 2144 additions and 1917 deletions

View file

@ -10,23 +10,32 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.memory_banks import (
BankParams,
ListMemoryBanksResponse,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -215,11 +224,11 @@ class CommonRoutingTableImpl(RoutingTable):
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
return await self.get_all_with_type("model")
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", identifier)
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
async def register_model(
self,
@ -265,8 +274,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -301,8 +312,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def list_memory_banks(self) -> ListMemoryBanksResponse:
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
@ -365,8 +376,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -410,8 +423,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -447,11 +462,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type(ResourceType.eval_task.value)
async def list_eval_tasks(self) -> ListEvalTasksResponse:
return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)
async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", eval_task_id)
async def register_eval_task(
self,
@ -485,14 +500,14 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
@ -551,11 +566,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)