Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 18:22:41 +00:00)

Commit b438dad8d2 (parent d0a25dd453): more idiomatic REST API

29 changed files with 2144 additions and 1917 deletions

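The pattern this commit applies is easiest to see in the @webmethod declarations below: action-style routes such as /models/list or /agents/create become resource-style routes with an explicit HTTP method, and list endpoints return a typed envelope instead of a bare list. A minimal sketch of the before/after shape, condensed from the hunks in this diff (the decorator is shown in comments so the snippet stands alone; the "Before"/"After" class names are illustrative only):

from typing import List, Protocol

from pydantic import BaseModel


class Model(BaseModel):
    identifier: str


class ListModelsResponse(BaseModel):
    data: List[Model]


class ModelsBefore(Protocol):
    # old style: verb in the route, no explicit method, bare list return
    # @webmethod(route="/models/list", method="GET")
    async def list_models(self) -> List[Model]: ...


class ModelsAfter(Protocol):
    # new style: resource route, explicit method, typed list envelope
    # @webmethod(route="/models", method="GET")
    async def list_models(self) -> ListModelsResponse: ...

    # identifiers move into the path
    # @webmethod(route="/models/{model_id}", method="DELETE")
    async def unregister_model(self, model_id: str) -> None: ...
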
@@ -8,7 +8,6 @@ import collections.abc
 import enum
 import inspect
 import typing
-import uuid
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
@@ -16,12 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 
 from termcolor import colored
 
-from ..strong_typing.inspection import (
-    get_signature,
-    is_type_enum,
-    is_type_optional,
-    unwrap_optional_type,
-)
+from ..strong_typing.inspection import get_signature
 
 
 def split_prefix(
@@ -113,9 +107,6 @@ class EndpointOperation:
 
     def get_route(self) -> str:
         if self.route is not None:
-            assert (
-                "_" not in self.route
-            ), f"route should not contain underscores: {self.route}"
             return "/".join(["", LLAMA_STACK_API_VERSION, self.route.lstrip("/")])
 
         route_parts = ["", LLAMA_STACK_API_VERSION, self.name]
@@ -265,42 +256,16 @@ def get_endpoint_operations(
                 f"parameter '{param_name}' in function '{func_name}' has no type annotation"
             )
 
-        if is_type_optional(param_type):
-            inner_type: type = unwrap_optional_type(param_type)
-        else:
-            inner_type = param_type
-
-        if prefix == "get" and (
-            inner_type is bool
-            or inner_type is int
-            or inner_type is float
-            or inner_type is str
-            or inner_type is uuid.UUID
-            or is_type_enum(inner_type)
-        ):
-            if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
-                if route_params is not None and param_name not in route_params:
-                    raise ValidationError(
-                        f"positional parameter '{param_name}' absent from user-defined route '{route}' for function '{func_name}'"
-                    )
-
-                # simple type maps to route path element, e.g. /study/{uuid}/{version}
-                path_params.append((param_name, param_type))
-            else:
-                if route_params is not None and param_name in route_params:
-                    raise ValidationError(
-                        f"query parameter '{param_name}' found in user-defined route '{route}' for function '{func_name}'"
-                    )
-
-                # simple type maps to key=value pair in query string
-                query_params.append((param_name, param_type))
-        else:
-            if route_params is not None and param_name in route_params:
-                raise ValidationError(
-                    f"user-defined route '{route}' for function '{func_name}' has parameter '{param_name}' of composite type: {param_type}"
-                )
-
-            request_params.append((param_name, param_type))
+        if prefix in ["get", "delete"]:
+            if route_params is not None and param_name in route_params:
+                path_params.append((param_name, param_type))
+            else:
+                query_params.append((param_name, param_type))
+        else:
+            if route_params is not None and param_name in route_params:
+                path_params.append((param_name, param_type))
+            else:
+                request_params.append((param_name, param_type))
 
         # check if function has explicit return type
         if signature.return_annotation is inspect.Signature.empty:
@@ -335,19 +300,18 @@ def get_endpoint_operations(
 
         response_type = process_type(return_type)
 
-        # set HTTP request method based on type of request and presence of payload
-        if not request_params:
-            if prefix in ["delete", "remove"]:
-                http_method = HTTPMethod.DELETE
-            else:
-                http_method = HTTPMethod.GET
-        else:
-            if prefix == "set":
-                http_method = HTTPMethod.PUT
-            elif prefix == "update":
-                http_method = HTTPMethod.PATCH
-            else:
-                http_method = HTTPMethod.POST
+        if prefix in ["delete", "remove"]:
+            http_method = HTTPMethod.DELETE
+        elif prefix == "post":
+            http_method = HTTPMethod.POST
+        elif prefix == "get":
+            http_method = HTTPMethod.GET
+        elif prefix == "set":
+            http_method = HTTPMethod.PUT
+        elif prefix == "update":
+            http_method = HTTPMethod.PATCH
+        else:
+            raise ValidationError(f"unknown prefix {prefix}")
 
         result.append(
             EndpointOperation(

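With the hunk above, the OpenAPI generator derives the HTTP method purely from the method-name prefix and rejects unknown prefixes, instead of inferring GET vs POST from the presence of body parameters. A rough standalone equivalent of the new mapping (the HTTPMethod enum and the ValueError here are stand-ins for the generator's own types):

from enum import Enum


class HTTPMethod(Enum):  # stand-in for the generator's HTTPMethod
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"


_PREFIX_TO_METHOD = {
    "delete": HTTPMethod.DELETE,
    "remove": HTTPMethod.DELETE,
    "post": HTTPMethod.POST,
    "get": HTTPMethod.GET,
    "set": HTTPMethod.PUT,
    "update": HTTPMethod.PATCH,
}


def http_method_for_prefix(prefix: str) -> HTTPMethod:
    # mirrors the if/elif chain added in the hunk above; unknown prefixes are an error
    try:
        return _PREFIX_TO_METHOD[prefix]
    except KeyError:
        raise ValueError(f"unknown prefix {prefix}")


assert http_method_for_prefix("set") is HTTPMethod.PUT
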
(File diff suppressed because it is too large)
(File diff suppressed because it is too large)

@@ -7,6 +7,7 @@
 from datetime import datetime
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
     AsyncIterator,
     Dict,
@@ -20,7 +21,6 @@ from typing import (
 
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
-from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
 from llama_stack.apis.inference import (
@@ -296,13 +296,13 @@ class AgentStepResponse(BaseModel):
 @runtime_checkable
 @trace_protocol
 class Agents(Protocol):
-    @webmethod(route="/agents/create")
+    @webmethod(route="/agents", method="POST")
     async def create_agent(
         self,
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...
 
-    @webmethod(route="/agents/turn/create")
+    @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
     async def create_agent_turn(
         self,
         agent_id: str,
@@ -318,36 +318,52 @@ class Agents(Protocol):
         toolgroups: Optional[List[AgentToolGroup]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
-    @webmethod(route="/agents/turn/get")
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
+    )
     async def get_agents_turn(
-        self, agent_id: str, session_id: str, turn_id: str
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
     ) -> Turn: ...
 
-    @webmethod(route="/agents/step/get")
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
+        method="GET",
+    )
     async def get_agents_step(
-        self, agent_id: str, session_id: str, turn_id: str, step_id: str
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
+        step_id: str,
     ) -> AgentStepResponse: ...
 
-    @webmethod(route="/agents/session/create")
+    @webmethod(route="/agents/{agent_id}/session", method="POST")
     async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
     ) -> AgentSessionCreateResponse: ...
 
-    @webmethod(route="/agents/session/get")
+    @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
     async def get_agents_session(
         self,
-        agent_id: str,
         session_id: str,
+        agent_id: str,
         turn_ids: Optional[List[str]] = None,
     ) -> Session: ...
 
-    @webmethod(route="/agents/session/delete")
-    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
+    @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
+    async def delete_agents_session(
+        self,
+        session_id: str,
+        agent_id: str,
+    ) -> None: ...
 
-    @webmethod(route="/agents/delete")
-    async def delete_agents(
+    @webmethod(route="/agents/{agent_id}", method="DELETE")
+    async def delete_agent(
         self,
         agent_id: str,
     ) -> None: ...

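Under the new Agents routes above, the identifiers that previously travelled in the request body become path segments, and get_route (first file of this diff) prepends the API version. A small illustration of the resulting paths; the "v1" value for LLAMA_STACK_API_VERSION is an assumption for the example:

LLAMA_STACK_API_VERSION = "v1"  # assumed value, for illustration only


def route_for_create_turn(agent_id: str, session_id: str) -> str:
    # POST /{version}/agents/{agent_id}/session/{session_id}/turn
    return f"/{LLAMA_STACK_API_VERSION}/agents/{agent_id}/session/{session_id}/turn"


print(route_for_create_turn("agent-123", "sess-456"))
# -> /v1/agents/agent-123/session/sess-456/turn
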
@@ -54,7 +54,7 @@ class BatchChatCompletionResponse(BaseModel):
 
 @runtime_checkable
 class BatchInference(Protocol):
-    @webmethod(route="/batch-inference/completion")
+    @webmethod(route="/batch-inference/completion", method="POST")
     async def batch_completion(
         self,
         model: str,
@@ -63,7 +63,7 @@ class BatchInference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
 
-    @webmethod(route="/batch-inference/chat-completion")
+    @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def batch_chat_completion(
         self,
         model: str,

@@ -29,7 +29,7 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore
 
-    @webmethod(route="/datasetio/get-rows-paginated", method="GET")
+    @webmethod(route="/datasetio/rows", method="GET")
     async def get_rows_paginated(
         self,
         dataset_id: str,
@@ -38,7 +38,7 @@ class DatasetIO(Protocol):
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult: ...
 
-    @webmethod(route="/datasetio/append-rows", method="POST")
+    @webmethod(route="/datasetio/rows", method="POST")
     async def append_rows(
         self, dataset_id: str, rows: List[Dict[str, Any]]
     ) -> None: ...

@@ -7,11 +7,9 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.common.content_types import URL
-
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
 
@@ -44,8 +42,12 @@ class DatasetInput(CommonDatasetFields, BaseModel):
     provider_dataset_id: Optional[str] = None
 
 
+class ListDatasetsResponse(BaseModel):
+    data: List[Dataset]
+
+
 class Datasets(Protocol):
-    @webmethod(route="/datasets/register", method="POST")
+    @webmethod(route="/datasets", method="POST")
     async def register_dataset(
         self,
         dataset_id: str,
@@ -56,16 +58,16 @@ class Datasets(Protocol):
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...
 
-    @webmethod(route="/datasets/get", method="GET")
+    @webmethod(route="/datasets/{dataset_id}", method="GET")
     async def get_dataset(
         self,
         dataset_id: str,
     ) -> Optional[Dataset]: ...
 
-    @webmethod(route="/datasets/list", method="GET")
-    async def list_datasets(self) -> List[Dataset]: ...
+    @webmethod(route="/datasets", method="GET")
+    async def list_datasets(self) -> ListDatasetsResponse: ...
 
-    @webmethod(route="/datasets/unregister", method="POST")
+    @webmethod(route="/datasets/{dataset_id}", method="DELETE")
     async def unregister_dataset(
         self,
         dataset_id: str,

@@ -7,9 +7,7 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field
-
 from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
@@ -76,7 +74,7 @@ class EvaluateResponse(BaseModel):
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/run-eval", method="POST")
+    @webmethod(route="/eval/run", method="POST")
     async def run_eval(
         self,
         task_id: str,
@@ -92,11 +90,11 @@ class Eval(Protocol):
         task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...
 
-    @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
+    @webmethod(route="/eval/jobs/{job_id}", method="GET")
+    async def job_status(self, job_id: str, task_id: str) -> Optional[JobStatus]: ...
 
-    @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, task_id: str, job_id: str) -> None: ...
+    @webmethod(route="/eval/jobs/cancel", method="POST")
+    async def job_cancel(self, job_id: str, task_id: str) -> None: ...
 
-    @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
+    @webmethod(route="/eval/jobs/{job_id}/result", method="GET")
+    async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...

@@ -6,7 +6,6 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.resource import Resource, ResourceType
@@ -40,15 +39,22 @@ class EvalTaskInput(CommonEvalTaskFields, BaseModel):
     provider_eval_task_id: Optional[str] = None
 
 
+class ListEvalTasksResponse(BaseModel):
+    data: List[EvalTask]
+
+
 @runtime_checkable
 class EvalTasks(Protocol):
-    @webmethod(route="/eval-tasks/list", method="GET")
-    async def list_eval_tasks(self) -> List[EvalTask]: ...
+    @webmethod(route="/eval-tasks", method="GET")
+    async def list_eval_tasks(self) -> ListEvalTasksResponse: ...
 
-    @webmethod(route="/eval-tasks/get", method="GET")
-    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
+    @webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
+    async def get_eval_task(
+        self,
+        eval_task_id: str,
+    ) -> Optional[EvalTask]: ...
 
-    @webmethod(route="/eval-tasks/register", method="POST")
+    @webmethod(route="/eval-tasks", method="POST")
     async def register_eval_task(
         self,
         eval_task_id: str,

@@ -291,7 +291,7 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore
 
-    @webmethod(route="/inference/completion")
+    @webmethod(route="/inference/completion", method="POST")
     async def completion(
         self,
         model_id: str,
@@ -302,7 +302,7 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
 
-    @webmethod(route="/inference/chat-completion")
+    @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model_id: str,
@@ -319,7 +319,7 @@ class Inference(Protocol):
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]: ...
 
-    @webmethod(route="/inference/embeddings")
+    @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
         self,
         model_id: str,

@@ -34,10 +34,14 @@ class VersionInfo(BaseModel):
     version: str
 
 
+class ListProvidersResponse(BaseModel):
+    data: List[ProviderInfo]
+
+
 @runtime_checkable
 class Inspect(Protocol):
     @webmethod(route="/providers/list", method="GET")
-    async def list_providers(self) -> Dict[str, ProviderInfo]: ...
+    async def list_providers(self) -> ListProvidersResponse: ...
 
     @webmethod(route="/routes/list", method="GET")
     async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...

@@ -50,7 +50,7 @@ class Memory(Protocol):
 
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
-    @webmethod(route="/memory/insert")
+    @webmethod(route="/memory/insert", method="POST")
     async def insert_documents(
         self,
         bank_id: str,
@@ -58,7 +58,7 @@ class Memory(Protocol):
         ttl_seconds: Optional[int] = None,
     ) -> None: ...
 
-    @webmethod(route="/memory/query")
+    @webmethod(route="/memory/query", method="POST")
     async def query_documents(
         self,
         bank_id: str,

@@ -16,7 +16,6 @@ from typing import (
 )
 
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.resource import Resource, ResourceType
@@ -133,16 +132,23 @@ class MemoryBankInput(BaseModel):
     provider_memory_bank_id: Optional[str] = None
 
 
+class ListMemoryBanksResponse(BaseModel):
+    data: List[MemoryBank]
+
+
 @runtime_checkable
 @trace_protocol
 class MemoryBanks(Protocol):
-    @webmethod(route="/memory-banks/list", method="GET")
-    async def list_memory_banks(self) -> List[MemoryBank]: ...
+    @webmethod(route="/memory-banks", method="GET")
+    async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
 
-    @webmethod(route="/memory-banks/get", method="GET")
-    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
+    @webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
+    async def get_memory_bank(
+        self,
+        memory_bank_id: str,
+    ) -> Optional[MemoryBank]: ...
 
-    @webmethod(route="/memory-banks/register", method="POST")
+    @webmethod(route="/memory-banks", method="POST")
     async def register_memory_bank(
         self,
         memory_bank_id: str,
@@ -151,5 +157,5 @@ class MemoryBanks(Protocol):
         provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...
 
-    @webmethod(route="/memory-banks/unregister", method="POST")
+    @webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
     async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

@@ -52,16 +52,23 @@ class ModelInput(CommonModelFields):
     model_config = ConfigDict(protected_namespaces=())
 
 
+class ListModelsResponse(BaseModel):
+    data: List[Model]
+
+
 @runtime_checkable
 @trace_protocol
 class Models(Protocol):
-    @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> List[Model]: ...
+    @webmethod(route="/models", method="GET")
+    async def list_models(self) -> ListModelsResponse: ...
 
-    @webmethod(route="/models/get", method="GET")
-    async def get_model(self, identifier: str) -> Optional[Model]: ...
+    @webmethod(route="/models/{model_id}", method="GET")
+    async def get_model(
+        self,
+        model_id: str,
+    ) -> Optional[Model]: ...
 
-    @webmethod(route="/models/register", method="POST")
+    @webmethod(route="/models", method="POST")
     async def register_model(
         self,
         model_id: str,
@@ -71,5 +78,8 @@ class Models(Protocol):
         model_type: Optional[ModelType] = None,
     ) -> Model: ...
 
-    @webmethod(route="/models/unregister", method="POST")
-    async def unregister_model(self, model_id: str) -> None: ...
+    @webmethod(route="/models/{model_id}", method="DELETE")
+    async def unregister_model(
+        self,
+        model_id: str,
+    ) -> None: ...

@@ -6,16 +6,13 @@
 
 from datetime import datetime
 from enum import Enum
-
 from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import URL
-
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 
@@ -159,6 +156,10 @@ class PostTrainingJobStatusResponse(BaseModel):
     checkpoints: List[Checkpoint] = Field(default_factory=list)
 
 
+class ListPostTrainingJobsResponse(BaseModel):
+    data: List[PostTrainingJob]
+
+
 @json_schema_type
 class PostTrainingJobArtifactsResponse(BaseModel):
     """Artifacts of a finetuning job."""
@@ -197,7 +198,7 @@ class PostTraining(Protocol):
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
 
     @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(

@@ -12,7 +12,6 @@ from pydantic import BaseModel, Field
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.shields import Shield
-
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
@@ -49,7 +48,7 @@ class ShieldStore(Protocol):
 class Safety(Protocol):
     shield_store: ShieldStore
 
-    @webmethod(route="/safety/run-shield")
+    @webmethod(route="/safety/run-shield", method="POST")
     async def run_shield(
         self,
         shield_id: str,

@@ -11,7 +11,6 @@ from pydantic import BaseModel
 
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 
-
 # mapping of metric to value
 ScoringResultRow = Dict[str, Any]
 
@@ -43,7 +42,7 @@ class ScoringFunctionStore(Protocol):
 class Scoring(Protocol):
     scoring_function_store: ScoringFunctionStore
 
-    @webmethod(route="/scoring/score-batch")
+    @webmethod(route="/scoring/score-batch", method="POST")
     async def score_batch(
         self,
         dataset_id: str,
@@ -51,7 +50,7 @@ class Scoring(Protocol):
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
-    @webmethod(route="/scoring/score")
+    @webmethod(route="/scoring/score", method="POST")
     async def score(
         self,
         input_rows: List[Dict[str, Any]],

@@ -21,7 +21,6 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.type_system import ParamType
-
 from llama_stack.apis.resource import Resource, ResourceType
 
 
@@ -129,15 +128,21 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
     provider_scoring_fn_id: Optional[str] = None
 
 
+class ListScoringFunctionsResponse(BaseModel):
+    data: List[ScoringFn]
+
+
 @runtime_checkable
 class ScoringFunctions(Protocol):
-    @webmethod(route="/scoring-functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFn]: ...
+    @webmethod(route="/scoring-functions", method="GET")
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
 
-    @webmethod(route="/scoring-functions/get", method="GET")
-    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
+    @webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
+    async def get_scoring_function(
+        self, scoring_fn_id: str, /
+    ) -> Optional[ScoringFn]: ...
 
-    @webmethod(route="/scoring-functions/register", method="POST")
+    @webmethod(route="/scoring-functions", method="POST")
     async def register_scoring_function(
         self,
         scoring_fn_id: str,

@@ -38,16 +38,20 @@ class ShieldInput(CommonShieldFields):
     provider_shield_id: Optional[str] = None
 
 
+class ListShieldsResponse(BaseModel):
+    data: List[Shield]
+
+
 @runtime_checkable
 @trace_protocol
 class Shields(Protocol):
-    @webmethod(route="/shields/list", method="GET")
-    async def list_shields(self) -> List[Shield]: ...
+    @webmethod(route="/shields", method="GET")
+    async def list_shields(self) -> ListShieldsResponse: ...
 
-    @webmethod(route="/shields/get", method="GET")
+    @webmethod(route="/shields/{identifier}", method="GET")
     async def get_shield(self, identifier: str) -> Optional[Shield]: ...
 
-    @webmethod(route="/shields/register", method="POST")
+    @webmethod(route="/shields", method="POST")
     async def register_shield(
         self,
         shield_id: str,

@@ -185,8 +185,8 @@ class Telemetry(Protocol):
         order_by: Optional[List[str]] = None,
     ) -> List[Trace]: ...
 
-    @webmethod(route="/telemetry/get-span-tree", method="POST")
-    async def get_span_tree(
+    @webmethod(route="/telemetry/query-span-tree", method="POST")
+    async def query_span_tree(
         self,
         span_id: str,
         attributes_to_return: Optional[List[str]] = None,

@@ -74,13 +74,21 @@ class ToolInvocationResult(BaseModel):
 
 class ToolStore(Protocol):
     def get_tool(self, tool_name: str) -> Tool: ...
-    def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
+    def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
 
 
+class ListToolGroupsResponse(BaseModel):
+    data: List[ToolGroup]
+
+
+class ListToolsResponse(BaseModel):
+    data: List[Tool]
+
+
 @runtime_checkable
 @trace_protocol
 class ToolGroups(Protocol):
-    @webmethod(route="/toolgroups/register", method="POST")
+    @webmethod(route="/toolgroups", method="POST")
     async def register_tool_group(
         self,
         toolgroup_id: str,
@@ -91,27 +99,33 @@ class ToolGroups(Protocol):
         """Register a tool group"""
         ...
 
-    @webmethod(route="/toolgroups/get", method="GET")
+    @webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
     async def get_tool_group(
         self,
         toolgroup_id: str,
     ) -> ToolGroup: ...
 
-    @webmethod(route="/toolgroups/list", method="GET")
-    async def list_tool_groups(self) -> List[ToolGroup]:
+    @webmethod(route="/toolgroups", method="GET")
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
         """List tool groups with optional provider"""
         ...
 
-    @webmethod(route="/tools/list", method="GET")
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+    @webmethod(route="/tools", method="GET")
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
         """List tools with optional tool group"""
         ...
 
-    @webmethod(route="/tools/get", method="GET")
-    async def get_tool(self, tool_name: str) -> Tool: ...
+    @webmethod(route="/tools/{tool_name}", method="GET")
+    async def get_tool(
+        self,
+        tool_name: str,
+    ) -> Tool: ...
 
-    @webmethod(route="/toolgroups/unregister", method="POST")
-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    @webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
+    async def unregister_toolgroup(
+        self,
+        toolgroup_id: str,
+    ) -> None:
         """Unregister a tool group"""
         ...
 

@@ -10,23 +10,32 @@ from pydantic import TypeAdapter
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
-from llama_stack.apis.datasets import Dataset, Datasets
-from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
+from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
+from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
 from llama_stack.apis.memory_banks import (
     BankParams,
+    ListMemoryBanksResponse,
     MemoryBank,
     MemoryBanks,
     MemoryBankType,
 )
-from llama_stack.apis.models import Model, Models, ModelType
+from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import (
+    ListScoringFunctionsResponse,
     ScoringFn,
     ScoringFnParams,
     ScoringFunctions,
 )
-from llama_stack.apis.shields import Shield, Shields
-from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
+from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
+from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
+    Tool,
+    ToolGroup,
+    ToolGroups,
+    ToolHost,
+)
 from llama_stack.distribution.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
@@ -215,11 +224,11 @@ class CommonRoutingTableImpl(RoutingTable):
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
-    async def list_models(self) -> List[Model]:
-        return await self.get_all_with_type("model")
+    async def list_models(self) -> ListModelsResponse:
+        return ListModelsResponse(data=await self.get_all_with_type("model"))
 
-    async def get_model(self, identifier: str) -> Optional[Model]:
-        return await self.get_object_by_identifier("model", identifier)
+    async def get_model(self, model_id: str) -> Optional[Model]:
+        return await self.get_object_by_identifier("model", model_id)
 
     async def register_model(
         self,
@@ -265,8 +274,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
-    async def list_shields(self) -> List[Shield]:
-        return await self.get_all_with_type(ResourceType.shield.value)
+    async def list_shields(self) -> ListShieldsResponse:
+        return ListShieldsResponse(
+            data=await self.get_all_with_type(ResourceType.shield.value)
+        )
 
     async def get_shield(self, identifier: str) -> Optional[Shield]:
         return await self.get_object_by_identifier("shield", identifier)
@@ -301,8 +312,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
 
 
 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        return await self.get_all_with_type(ResourceType.memory_bank.value)
+    async def list_memory_banks(self) -> ListMemoryBanksResponse:
+        return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
 
     async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
         return await self.get_object_by_identifier("memory_bank", memory_bank_id)
@@ -365,8 +376,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
 
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
-    async def list_datasets(self) -> List[Dataset]:
-        return await self.get_all_with_type(ResourceType.dataset.value)
+    async def list_datasets(self) -> ListDatasetsResponse:
+        return ListDatasetsResponse(
+            data=await self.get_all_with_type(ResourceType.dataset.value)
+        )
 
     async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
         return await self.get_object_by_identifier("dataset", dataset_id)
@@ -410,8 +423,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
-    async def list_scoring_functions(self) -> List[ScoringFn]:
-        return await self.get_all_with_type(ResourceType.scoring_function.value)
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+        return ListScoringFunctionsResponse(
+            data=await self.get_all_with_type(ResourceType.scoring_function.value)
+        )
 
     async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
         return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@@ -447,11 +462,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
 
 
 class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
-    async def list_eval_tasks(self) -> List[EvalTask]:
-        return await self.get_all_with_type(ResourceType.eval_task.value)
+    async def list_eval_tasks(self) -> ListEvalTasksResponse:
+        return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
 
-    async def get_eval_task(self, name: str) -> Optional[EvalTask]:
-        return await self.get_object_by_identifier("eval_task", name)
+    async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
+        return await self.get_object_by_identifier("eval_task", eval_task_id)
 
     async def register_eval_task(
         self,
@@ -485,14 +500,14 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
 
 
 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
         tools = await self.get_all_with_type("tool")
-        if tool_group_id:
-            tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
-        return tools
+        if toolgroup_id:
+            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
+        return ListToolsResponse(data=tools)
 
-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return await self.get_all_with_type("tool_group")
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
 
     async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
         return await self.get_object_by_identifier("tool_group", toolgroup_id)
@@ -551,11 +566,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             )
         )
 
-    async def unregister_tool_group(self, tool_group_id: str) -> None:
-        tool_group = await self.get_tool_group(tool_group_id)
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
+        tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
-            raise ValueError(f"Tool group {tool_group_id} not found")
-        tools = await self.list_tools(tool_group_id)
+            raise ValueError(f"Tool group {toolgroup_id} not found")
+        tools = await self.list_tools(toolgroup_id).data
         for tool in tools:
             await self.unregister_object(tool)
         await self.unregister_object(tool_group)

@@ -14,16 +14,13 @@ import signal
 import sys
 import traceback
 import warnings
-
 from contextlib import asynccontextmanager
-
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, List, Union
 
 import yaml
-
-from fastapi import Body, FastAPI, HTTPException, Request
+from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
@@ -31,7 +28,6 @@ from termcolor import cprint
 from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import StackRunConfig
-
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -41,13 +37,11 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
-
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
-
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
@@ -56,7 +50,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
 
 from .endpoints import get_all_api_endpoints
 
-
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
 
@@ -178,7 +171,7 @@ async def sse_generator(event_gen):
     )
 
 
-def create_dynamic_typed_route(func: Any, method: str):
+def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
         set_request_provider_data(request.headers)
 
@@ -196,6 +189,7 @@ def create_dynamic_typed_route(func: Any, method: str):
             raise translate_exception(e) from e
 
     sig = inspect.signature(func)
+
     new_params = [
         inspect.Parameter(
             "request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
@@ -203,12 +197,21 @@ def create_dynamic_typed_route(func: Any, method: str):
     ]
     new_params.extend(sig.parameters.values())
 
+    path_params = extract_path_params(route)
     if method == "post":
-        # make sure every parameter is annotated with Body() so FASTAPI doesn't
-        # do anything too intelligent and ask for some parameters in the query
-        # and some in the body
+        # Annotate parameters that are in the path with Path(...) and others with Body(...)
         new_params = [new_params[0]] + [
-            param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
+            (
+                param.replace(
+                    annotation=Annotated[
+                        param.annotation, FastapiPath(..., title=param.name)
+                    ]
+                )
+                if param.name in path_params
+                else param.replace(
+                    annotation=Annotated[param.annotation, Body(..., embed=True)]
+                )
+            )
             for param in new_params[1:]
         ]
 
@@ -386,6 +389,7 @@ def main():
                 create_dynamic_typed_route(
                     impl_method,
                     endpoint.method,
+                    endpoint.route,
                 )
             )
 
@@ -409,5 +413,13 @@ def main():
     uvicorn.run(app, host=listen_host, port=args.port)
 
 
+def extract_path_params(route: str) -> List[str]:
+    segments = route.split("/")
+    params = [
+        seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
+    ]
+    return params
+
+
 if __name__ == "__main__":
     main()

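The extract_path_params helper added above just pulls the {placeholder} names out of a route string; those names are then annotated with FastAPI's Path(...) while every other parameter keeps the Body(..., embed=True) annotation. A small standalone check of that behavior, using a route that appears earlier in this diff (the function body is copied from the hunk above):

from typing import List


def extract_path_params(route: str) -> List[str]:
    # same logic as the server helper shown above
    segments = route.split("/")
    return [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]


print(extract_path_params("/agents/{agent_id}/session/{session_id}/turn/{turn_id}"))
# -> ['agent_id', 'session_id', 'turn_id']
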
@@ -93,7 +93,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
             await method(**obj.model_dump())
 
         method = getattr(impls[api], list_method)
-        for obj in await method():
+        response = await method()
+
+        objects_to_process = response.data if hasattr(response, "data") else response
+
+        for obj in objects_to_process:
             log.info(
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
             )

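Because list endpoints now return List*Response envelopes rather than bare lists, callers read .data; the hasattr check above keeps register_resources working for any API that still returns a plain list. A minimal sketch of the caller-side pattern (the response class mirrors the ones defined earlier in this diff; the models_api argument and loop body are illustrative):

from typing import List

from pydantic import BaseModel


class Model(BaseModel):
    identifier: str
    provider_id: str = "inline"


class ListModelsResponse(BaseModel):
    data: List[Model]


async def log_models(models_api) -> None:
    # list_models() now returns an envelope, so iterate over .data
    response: ListModelsResponse = await models_api.list_models()
    for model in response.data:
        print(f"Model: {model.identifier} served by {model.provider_id}")
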
@@ -624,6 +624,10 @@ class ChatAgent(ShieldRunnerMixin):
                         step_type=StepType.tool_execution.value,
                         step_id=step_id,
                         tool_call=tool_call,
+                        delta=ToolCallDelta(
+                            parse_status=ToolCallParseStatus.in_progress,
+                            content=tool_call,
+                        ),
                     )
                 )
             )
@@ -735,8 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
         for toolgroup_name in agent_config_toolgroups:
             if toolgroup_name not in toolgroups_for_turn_set:
                 continue
-            tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
-            for tool_def in tools:
+            tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
+            for tool_def in tools.data:
                 if (
                     toolgroup_name.startswith("builtin")
                     and toolgroup_name != MEMORY_GROUP

@@ -223,5 +223,5 @@ class MetaReferenceAgentsImpl(Agents):
     async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
         await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
 
-    async def delete_agents(self, agent_id: str) -> None:
+    async def delete_agent(self, agent_id: str) -> None:
         await self.persistence_store.delete(f"agent:{agent_id}")

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import webmethod
 
@@ -14,6 +14,7 @@ from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
     JobStatus,
+    ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
@@ -114,8 +115,8 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    async def get_training_jobs(self) -> List[PostTrainingJob]:
-        return self.jobs_list
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        return ListPostTrainingJobsResponse(data=self.jobs_list)
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(

@@ -249,7 +249,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             order_by=order_by,
         )
 
-    async def get_span_tree(
+    async def query_span_tree(
         self,
         span_id: str,
         attributes_to_return: Optional[List[str]] = None,

@@ -83,13 +83,13 @@ class TestClientTool(ClientTool):
 def agent_config(llama_stack_client):
     available_models = [
         model.identifier
-        for model in llama_stack_client.models.list()
+        for model in llama_stack_client.models.list().data
        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
     print(f"Using model: {model_id}")
     available_shields = [
-        shield.identifier for shield in llama_stack_client.shields.list()
+        shield.identifier for shield in llama_stack_client.shields.list().data
     ]
     available_shields = available_shields[:1]
     print(f"Using shield: {available_shields}")