More idiomatic REST API (#765)

# What does this PR do?

This PR changes our API to follow a more idiomatic REST style: paths identify resources, and HTTP methods indicate the action being performed.
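
For illustration, here is the shape of the change in one protocol, condensed from the Agents diff below. This is a sketch, not the full class: the remaining methods and the request/response models are elided, and `AgentConfig`/`AgentCreateResponse` are referenced as forward references.

```python
from typing import Protocol

# @webmethod is the decorator already used throughout the API protocols.
from llama_models.schema_utils import webmethod


class Agents(Protocol):
    # Before: the action was encoded in the route path, e.g.
    #   @webmethod(route="/agents/create")
    #   @webmethod(route="/agents/delete")
    # After: the path names the resource and the HTTP method names the action.
    @webmethod(route="/agents", method="POST")
    async def create_agent(self, agent_config: "AgentConfig") -> "AgentCreateResponse": ...

    @webmethod(route="/agents/{agent_id}", method="DELETE")
    async def delete_agent(self, agent_id: str) -> None: ...
```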

Changes made to the generator:
1) Removed the "get" prefix check, since it isn't required and the parameter handling it gated is needed for other method types too.
2) Removed the underscore check on routes, since path variables can legitimately contain "_".



## Test Plan

LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v tests/client-sdk/agents/test_agents.py
Dinesh Yeduguru 2025-01-15 13:20:09 -08:00 committed by GitHub
parent 6deef1ece0
commit 7fb2c1c48d
29 changed files with 2144 additions and 1917 deletions


@ -8,7 +8,6 @@ import collections.abc
import enum
import inspect
import typing
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
@ -16,12 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from termcolor import colored
from ..strong_typing.inspection import (
get_signature,
is_type_enum,
is_type_optional,
unwrap_optional_type,
)
from ..strong_typing.inspection import get_signature
def split_prefix(
@ -113,9 +107,6 @@ class EndpointOperation:
def get_route(self) -> str:
if self.route is not None:
assert (
"_" not in self.route
), f"route should not contain underscores: {self.route}"
return "/".join(["", LLAMA_STACK_API_VERSION, self.route.lstrip("/")])
route_parts = ["", LLAMA_STACK_API_VERSION, self.name]
@ -265,42 +256,16 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
else:
inner_type = param_type
if prefix == "get" and (
inner_type is bool
or inner_type is int
or inner_type is float
or inner_type is str
or inner_type is uuid.UUID
or is_type_enum(inner_type)
):
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
if route_params is not None and param_name not in route_params:
raise ValidationError(
f"positional parameter '{param_name}' absent from user-defined route '{route}' for function '{func_name}'"
)
# simple type maps to route path element, e.g. /study/{uuid}/{version}
if prefix in ["get", "delete"]:
if route_params is not None and param_name in route_params:
path_params.append((param_name, param_type))
else:
if route_params is not None and param_name in route_params:
raise ValidationError(
f"query parameter '{param_name}' found in user-defined route '{route}' for function '{func_name}'"
)
# simple type maps to key=value pair in query string
query_params.append((param_name, param_type))
else:
if route_params is not None and param_name in route_params:
raise ValidationError(
f"user-defined route '{route}' for function '{func_name}' has parameter '{param_name}' of composite type: {param_type}"
)
request_params.append((param_name, param_type))
path_params.append((param_name, param_type))
else:
request_params.append((param_name, param_type))
# check if function has explicit return type
if signature.return_annotation is inspect.Signature.empty:
@ -335,19 +300,18 @@ def get_endpoint_operations(
response_type = process_type(return_type)
# set HTTP request method based on type of request and presence of payload
if not request_params:
if prefix in ["delete", "remove"]:
http_method = HTTPMethod.DELETE
else:
elif prefix == "post":
http_method = HTTPMethod.POST
elif prefix == "get":
http_method = HTTPMethod.GET
else:
if prefix == "set":
elif prefix == "set":
http_method = HTTPMethod.PUT
elif prefix == "update":
http_method = HTTPMethod.PATCH
else:
http_method = HTTPMethod.POST
raise ValidationError(f"unknown prefix {prefix}")
result.append(
EndpointOperation(
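
For reference, a standalone sketch approximating the generator's new prefix-to-method dispatch shown in the hunk above. `HTTPMethod` and `ValidationError` are simplified stand-ins for the generator's own types, so treat this as an illustration rather than the exact implementation; the key behavioral change is that an unknown prefix is now rejected instead of silently falling back to POST.

```python
from enum import Enum


class HTTPMethod(Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"


class ValidationError(Exception):
    pass


def method_for_prefix(prefix: str) -> HTTPMethod:
    # Derive the HTTP verb from the function-name prefix, e.g. "get_model" -> "get".
    if prefix in ("delete", "remove"):
        return HTTPMethod.DELETE
    if prefix == "post":
        return HTTPMethod.POST
    if prefix == "get":
        return HTTPMethod.GET
    if prefix == "set":
        return HTTPMethod.PUT
    if prefix == "update":
        return HTTPMethod.PATCH
    raise ValidationError(f"unknown prefix {prefix}")


assert method_for_prefix("get") is HTTPMethod.GET
assert method_for_prefix("remove") is HTTPMethod.DELETE
```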

(File diff suppressed because it is too large.)

(File diff suppressed because it is too large.)


@ -7,6 +7,7 @@
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
AsyncIterator,
Dict,
@ -20,7 +21,6 @@ from typing import (
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import (
@ -296,13 +296,13 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
@webmethod(route="/agents/create")
@webmethod(route="/agents", method="POST")
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
@webmethod(route="/agents/turn/create")
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn(
self,
agent_id: str,
@ -318,36 +318,52 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
async def get_agents_turn(
self, agent_id: str, session_id: str, turn_id: str
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agents/step/get")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
)
async def get_agents_step(
self, agent_id: str, session_id: str, turn_id: str, step_id: str
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse: ...
@webmethod(route="/agents/session/create")
@webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
@webmethod(route="/agents/session/get")
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agents/session/delete")
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
@webmethod(route="/agents/delete")
async def delete_agents(
@webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent(
self,
agent_id: str,
) -> None: ...
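
From a client's point of view, the same change looks like this. A hypothetical sketch using `requests`; the base URL, the `/v1` version prefix, and the agent id are assumptions for illustration (the real prefix comes from `LLAMA_STACK_API_VERSION`).

```python
import requests

BASE = "http://localhost:5000/v1"  # assumption: adjust host, port, and version prefix to your deployment
agent_id = "abc-123"               # hypothetical identifier

# Before: action in the path, identifier in the request body.
requests.post(f"{BASE}/agents/delete", json={"agent_id": agent_id})

# After: the path addresses the resource, the HTTP verb expresses the action.
requests.delete(f"{BASE}/agents/{agent_id}")
```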


@ -54,7 +54,7 @@ class BatchChatCompletionResponse(BaseModel):
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch-inference/completion")
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
self,
model: str,
@ -63,7 +63,7 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch-inference/chat-completion")
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
self,
model: str,


@ -29,7 +29,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/get-rows-paginated", method="GET")
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
self,
dataset_id: str,
@ -38,7 +38,7 @@ class DatasetIO(Protocol):
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/append-rows", method="POST")
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...


@ -7,11 +7,9 @@
from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@ -44,8 +42,12 @@ class DatasetInput(CommonDatasetFields, BaseModel):
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):
data: List[Dataset]
class Datasets(Protocol):
@webmethod(route="/datasets/register", method="POST")
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
dataset_id: str,
@ -56,16 +58,16 @@ class Datasets(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/get", method="GET")
@webmethod(route="/datasets/{dataset_id}", method="GET")
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
@webmethod(route="/datasets/unregister", method="POST")
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,


@ -7,9 +7,7 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
@ -76,7 +74,7 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
@webmethod(route="/eval/run-eval", method="POST")
@webmethod(route="/eval/run", method="POST")
async def run_eval(
self,
task_id: str,
@ -92,11 +90,11 @@ class Eval(Protocol):
task_config: EvalTaskConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/jobs/{job_id}", method="GET")
async def job_status(self, job_id: str, task_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/jobs/cancel", method="POST")
async def job_cancel(self, job_id: str, task_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
@webmethod(route="/eval/jobs/{job_id}/result", method="GET")
async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...


@ -6,7 +6,6 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
@ -40,15 +39,22 @@ class EvalTaskInput(CommonEvalTaskFields, BaseModel):
provider_eval_task_id: Optional[str] = None
class ListEvalTasksResponse(BaseModel):
data: List[EvalTask]
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval-tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTask]: ...
@webmethod(route="/eval-tasks", method="GET")
async def list_eval_tasks(self) -> ListEvalTasksResponse: ...
@webmethod(route="/eval-tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
async def get_eval_task(
self,
eval_task_id: str,
) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/register", method="POST")
@webmethod(route="/eval-tasks", method="POST")
async def register_eval_task(
self,
eval_task_id: str,


@ -291,7 +291,7 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
@webmethod(route="/inference/completion", method="POST")
async def completion(
self,
model_id: str,
@ -302,7 +302,7 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat-completion")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
model_id: str,
@ -319,7 +319,7 @@ class Inference(Protocol):
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ...
@webmethod(route="/inference/embeddings")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
model_id: str,


@ -34,10 +34,14 @@ class VersionInfo(BaseModel):
version: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/routes/list", method="GET")
async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...


@ -50,7 +50,7 @@ class Memory(Protocol):
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory/insert")
@webmethod(route="/memory/insert", method="POST")
async def insert_documents(
self,
bank_id: str,
@ -58,7 +58,7 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/query")
@webmethod(route="/memory/query", method="POST")
async def query_documents(
self,
bank_id: str,


@ -16,7 +16,6 @@ from typing import (
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
@ -133,16 +132,23 @@ class MemoryBankInput(BaseModel):
provider_memory_bank_id: Optional[str] = None
class ListMemoryBanksResponse(BaseModel):
data: List[MemoryBank]
@runtime_checkable
@trace_protocol
class MemoryBanks(Protocol):
@webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory-banks", method="GET")
async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
@webmethod(route="/memory-banks/get", method="GET")
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
async def get_memory_bank(
self,
memory_bank_id: str,
) -> Optional[MemoryBank]: ...
@webmethod(route="/memory-banks/register", method="POST")
@webmethod(route="/memory-banks", method="POST")
async def register_memory_bank(
self,
memory_bank_id: str,
@ -151,5 +157,5 @@ class MemoryBanks(Protocol):
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory-banks/unregister", method="POST")
@webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...


@ -52,16 +52,23 @@ class ModelInput(CommonModelFields):
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: List[Model]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[Model]: ...
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, identifier: str) -> Optional[Model]: ...
@webmethod(route="/models/{model_id}", method="GET")
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
@webmethod(route="/models/register", method="POST")
@webmethod(route="/models", method="POST")
async def register_model(
self,
model_id: str,
@ -71,5 +78,8 @@ class Models(Protocol):
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/unregister", method="POST")
async def unregister_model(self, model_id: str) -> None: ...
@webmethod(route="/models/{model_id}", method="DELETE")
async def unregister_model(
self,
model_id: str,
) -> None: ...


@ -6,16 +6,13 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
@ -159,6 +156,10 @@ class PostTrainingJobStatusResponse(BaseModel):
checkpoints: List[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
@ -197,7 +198,7 @@ class PostTraining(Protocol):
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(


@ -12,7 +12,6 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -49,7 +48,7 @@ class ShieldStore(Protocol):
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run-shield")
@webmethod(route="/safety/run-shield", method="POST")
async def run_shield(
self,
shield_id: str,


@ -11,7 +11,6 @@ from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@ -43,7 +42,7 @@ class ScoringFunctionStore(Protocol):
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch")
@webmethod(route="/scoring/score-batch", method="POST")
async def score_batch(
self,
dataset_id: str,
@ -51,7 +50,7 @@ class Scoring(Protocol):
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")
@webmethod(route="/scoring/score", method="POST")
async def score(
self,
input_rows: List[Dict[str, Any]],


@ -21,7 +21,6 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@ -129,15 +128,21 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/get", method="GET")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions/register", method="POST")
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,


@ -38,16 +38,20 @@ class ShieldInput(CommonShieldFields):
provider_shield_id: Optional[str] = None
class ListShieldsResponse(BaseModel):
data: List[Shield]
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[Shield]: ...
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/get", method="GET")
@webmethod(route="/shields/{identifier}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields/register", method="POST")
@webmethod(route="/shields", method="POST")
async def register_shield(
self,
shield_id: str,


@ -185,8 +185,8 @@ class Telemetry(Protocol):
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
@webmethod(route="/telemetry/get-span-tree", method="POST")
async def get_span_tree(
@webmethod(route="/telemetry/query-span-tree", method="POST")
async def query_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,


@ -74,13 +74,21 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup]
class ListToolsResponse(BaseModel):
data: List[Tool]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
@webmethod(route="/toolgroups", method="POST")
async def register_tool_group(
self,
toolgroup_id: str,
@ -91,27 +99,33 @@ class ToolGroups(Protocol):
"""Register a tool group"""
...
@webmethod(route="/toolgroups/get", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
async def list_tool_groups(self) -> List[ToolGroup]:
@webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider"""
...
@webmethod(route="/tools/list", method="GET")
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
"""List tools with optional tool group"""
...
@webmethod(route="/tools/get", method="GET")
async def get_tool(self, tool_name: str) -> Tool: ...
@webmethod(route="/tools/{tool_name}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
@webmethod(route="/toolgroups/unregister", method="POST")
async def unregister_tool_group(self, tool_group_id: str) -> None:
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,
) -> None:
"""Unregister a tool group"""
...


@ -10,23 +10,32 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.memory_banks import (
BankParams,
ListMemoryBanksResponse,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -215,11 +224,11 @@ class CommonRoutingTableImpl(RoutingTable):
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
return await self.get_all_with_type("model")
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", identifier)
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
async def register_model(
self,
@ -265,8 +274,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -301,8 +312,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def list_memory_banks(self) -> ListMemoryBanksResponse:
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
@ -365,8 +376,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -410,8 +423,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -447,11 +462,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type(ResourceType.eval_task.value)
async def list_eval_tasks(self) -> ListEvalTasksResponse:
return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)
async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", eval_task_id)
async def register_eval_task(
self,
@ -485,14 +500,14 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
@ -551,11 +566,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)
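
The list endpoints now return typed wrapper objects with a `data` field instead of bare lists. A minimal sketch of that shape, using a simplified stand-in for `Model` (the real class carries more fields):

```python
from typing import List

from pydantic import BaseModel


class Model(BaseModel):
    # Simplified stand-in for llama_stack.apis.models.Model.
    identifier: str


class ListModelsResponse(BaseModel):
    data: List[Model]


# Routing tables wrap their results, and callers unwrap .data:
resp = ListModelsResponse(data=[Model(identifier="meta-llama/Llama-3.1-8B-Instruct")])
names = [m.identifier for m in resp.data]
```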


@ -14,16 +14,13 @@ import signal
import sys
import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@ -31,7 +28,6 @@ from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
@ -41,13 +37,11 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -56,7 +50,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -178,7 +171,7 @@ async def sse_generator(event_gen):
)
def create_dynamic_typed_route(func: Any, method: str):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
@ -196,6 +189,7 @@ def create_dynamic_typed_route(func: Any, method: str):
raise translate_exception(e) from e
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
@ -203,12 +197,21 @@ def create_dynamic_typed_route(func: Any, method: str):
]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
# Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
(
param.replace(
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
if param.name in path_params
else param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
)
for param in new_params[1:]
]
@ -386,6 +389,7 @@ def main():
create_dynamic_typed_route(
impl_method,
endpoint.method,
endpoint.route,
)
)
@ -409,5 +413,13 @@ def main():
uvicorn.run(app, host=listen_host, port=args.port)
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
return params
if __name__ == "__main__":
main()
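
A standalone copy of the new `extract_path_params` helper with a usage example; the route template is taken from the Agents protocol above.

```python
from typing import List


def extract_path_params(route: str) -> List[str]:
    # Mirrors the helper added above: pull "{name}" segments out of a route template.
    segments = route.split("/")
    return [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]


print(extract_path_params("/agents/{agent_id}/session/{session_id}/turn/{turn_id}"))
# -> ['agent_id', 'session_id', 'turn_id']
```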


@ -93,7 +93,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
await method(**obj.model_dump())
method = getattr(impls[api], list_method)
for obj in await method():
response = await method()
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
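
A minimal illustration of the compatibility shim above: during the transition, `register_resources` accepts either the new wrapper objects (which carry their items in `.data`) or a plain list. `Wrapped` and `unwrap` are hypothetical names used only for this sketch.

```python
class Wrapped:
    """Stand-in for the new List*Response models, which expose their items via .data."""

    def __init__(self, data):
        self.data = data


def unwrap(response):
    # Same pattern as the server code above.
    return response.data if hasattr(response, "data") else response


assert unwrap(Wrapped([1, 2])) == [1, 2]
assert unwrap([1, 2]) == [1, 2]
```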


@ -624,6 +624,10 @@ class ChatAgent(ShieldRunnerMixin):
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
content=tool_call,
),
)
)
)
@ -735,8 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != MEMORY_GROUP


@ -223,5 +223,5 @@ class MetaReferenceAgentsImpl(Agents):
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
async def delete_agents(self, agent_id: str) -> None:
async def delete_agent(self, agent_id: str) -> None:
await self.persistence_store.delete(f"agent:{agent_id}")


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_models.schema_utils import webmethod
@ -14,6 +14,7 @@ from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
@ -114,8 +115,8 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> List[PostTrainingJob]:
return self.jobs_list
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=self.jobs_list)
@webmethod(route="/post-training/job/status")
async def get_training_job_status(


@ -249,7 +249,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
order_by=order_by,
)
async def get_span_tree(
async def query_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,


@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
def agent_config(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list()
for model in llama_stack_client.models.list().data
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
model_id = available_models[0]
print(f"Using model: {model_id}")
available_shields = [
shield.identifier for shield in llama_stack_client.shields.list()
shield.identifier for shield in llama_stack_client.shields.list().data
]
available_shields = available_shields[:1]
print(f"Using shield: {available_shields}")