More idiomatic REST API (#765)

# What does this PR do?

This PR changes our API to follow more idiomatic REST API approaches of
having paths being resources and methods indicating the action being
performed.

Changes made to generator:
1) removed the prefix check of "get" as its not required and is actually
needed for other method types too
2) removed _ check on path since variables can have "_"



## Test Plan

LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v
tests/client-sdk/agents/test_agents.py
This commit is contained in:
Dinesh Yeduguru 2025-01-15 13:20:09 -08:00 committed by GitHub
parent 6deef1ece0
commit 7fb2c1c48d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2144 additions and 1917 deletions

View file

@ -10,23 +10,32 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.memory_banks import (
BankParams,
ListMemoryBanksResponse,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -215,11 +224,11 @@ class CommonRoutingTableImpl(RoutingTable):
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
return await self.get_all_with_type("model")
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", identifier)
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
async def register_model(
self,
@ -265,8 +274,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -301,8 +312,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def list_memory_banks(self) -> ListMemoryBanksResponse:
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
@ -365,8 +376,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -410,8 +423,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -447,11 +462,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type(ResourceType.eval_task.value)
async def list_eval_tasks(self) -> ListEvalTasksResponse:
return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)
async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", eval_task_id)
async def register_eval_task(
self,
@ -485,14 +500,14 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
@ -551,11 +566,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)