more idiomatic REST API

This commit is contained in:
Dinesh Yeduguru 2025-01-14 14:52:32 -08:00
parent d0a25dd453
commit b438dad8d2
29 changed files with 2144 additions and 1917 deletions

View file

@ -624,6 +624,10 @@ class ChatAgent(ShieldRunnerMixin):
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
content=tool_call,
),
)
)
)
@ -735,8 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != MEMORY_GROUP

View file

@ -223,5 +223,5 @@ class MetaReferenceAgentsImpl(Agents):
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
async def delete_agents(self, agent_id: str) -> None:
async def delete_agent(self, agent_id: str) -> None:
await self.persistence_store.delete(f"agent:{agent_id}")

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_models.schema_utils import webmethod
@ -14,6 +14,7 @@ from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
@ -114,8 +115,8 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> List[PostTrainingJob]:
return self.jobs_list
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=self.jobs_list)
@webmethod(route="/post-training/job/status")
async def get_training_job_status(

View file

@ -249,7 +249,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
order_by=order_by,
)
async def get_span_tree(
async def query_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,