Merge-related changes.

This commit is contained in:
ilya-kolchinsky 2025-04-02 19:56:44 +02:00
commit 60e9f46856
456 changed files with 38636 additions and 10892 deletions

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
]
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
@ -312,20 +309,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
name="AgentTurnResponseEventPayload",
)
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
@ -370,7 +365,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems.
@ -399,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
"""
@webmethod(route="/agents", method="POST")
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent(
self,
agent_config: AgentConfig,
@ -411,7 +405,9 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn(
self,
agent_id: str,
@ -443,13 +439,14 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
@ -460,7 +457,6 @@ class Agents(Protocol):
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
@ -506,7 +502,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session(
self,
agent_id: str,

View file

@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here
InterleavedContentItem = register_schema(
Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
],
name="InterleavedContentItem",
)
InterleavedContentItem = Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)
ContentDelta = Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")

View file

@ -10,14 +10,15 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
cancelled = "cancelled"
@json_schema_type
class Job(BaseModel):
job_id: str
status: JobStatus

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PaginatedResponse(BaseModel):
"""A generic paginated response that follows a simple format.
:param data: The list of items for the current page
:param has_more: Whether there are more items available after this set
"""
data: List[Dict[str, Any]]
has_more: bool

View file

@ -72,24 +72,22 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog"
ParamType = register_schema(
Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
ParamType = Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
name="ParamType",
)
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")
"""
# TODO: recursive definition of ParamType in these containers

View file

@ -6,26 +6,9 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
"""
A paginated list of rows from a dataset.
:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
"""
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
next_page_token: Optional[str] = None
from llama_stack.schema_utils import webmethod
class DatasetStore(Protocol):
@ -37,22 +20,28 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> PaginatedResponse:
"""Get a paginated list of rows from a dataset.
Uses offset-based pagination where:
- start_index: The starting index (0-based). If None, starts from beginning.
- limit: Number of items to return. If None or -1, returns all items.
The response includes:
- data: List of items for the current page
- has_more: Whether there are more items available after this set
:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
"""
...
@webmethod(route="/datasetio/rows", method="POST")
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -4,19 +4,100 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(str, Enum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
dataset_schema: Dict[str, ParamType]
url: URL
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
@ -38,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):
@ -50,19 +129,75 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
dataset_id: Optional[str] = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
) -> Dataset: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...

View file

@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Api(Enum):
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
@ -34,6 +35,7 @@ class Api(Enum):
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
preprocessors = "preprocessors"
# built-in API

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
config: AgentConfig
EvalCandidate = register_schema(
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
name="EvalCandidate",
)
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
@ -117,7 +115,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info(
self,
upload_id: str,
) -> Optional[FileUploadResponse]:
) -> FileUploadResponse:
"""
Returns information about an existsing upload session
@ -164,7 +164,7 @@ class Files(Protocol):
self,
bucket: str,
key: str,
) -> FileResponse:
) -> None:
"""
Delete a file identified by a bucket and key.

View file

@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
:param role: Must be "tool" to identify this as a tool response
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was called
:param content: The response content from the tool
"""
role: Literal["tool"] = "tool"
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
@ -146,18 +144,16 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema(
Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
Field(discriminator="role"),
Message = Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
name="Message",
)
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
@ -265,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
ResponseFormat = register_schema(
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
],
name="ResponseFormat",
)
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
# This is an internally used class
@ -285,7 +279,7 @@ class CompletionRequest(BaseModel):
@json_schema_type
class CompletionResponse(BaseModel):
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
@ -299,7 +293,7 @@ class CompletionResponse(BaseModel):
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
@ -368,7 +362,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -378,7 +372,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
@ -400,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
async def get_model(self, identifier: str) -> Model: ...
class TextTruncation(Enum):
@ -437,7 +431,7 @@ class Inference(Protocol):
- Embedding models: these models generate embeddings to be used for semantic search.
"""
model_store: ModelStore
model_store: ModelStore | None = None
@webmethod(route="/inference/completion", method="POST")
async def completion(

View file

@ -11,13 +11,6 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
@ -36,19 +29,12 @@ class VersionInfo(BaseModel):
version: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...

View file

@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
) -> Model: ...
@webmethod(route="/models", method="POST")
async def register_model(

View file

@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
name="AlgorithmConfig",
)
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
@ -202,10 +200,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -58,7 +58,7 @@ PreprocessorChain = List[PreprocessorChainElement]
@json_schema_type
class PreprocessorResponse(BaseModel):
success: bool
output_data_type: PreprocessingDataType
output_data_type: PreprocessingDataType | None
results: Optional[List[PreprocessingDataElement]] = None
@ -68,7 +68,7 @@ class PreprocessorStore(Protocol):
@runtime_checkable
class Preprocessing(Protocol):
preprocessor_store: PreprocessorStore
preprocessor_store: PreprocessorStore | None
@webmethod(route="/preprocess", method="POST")
async def preprocess(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class Preprocessor(Resource):
type: Literal[ResourceType.preprocessor.value] = ResourceType.preprocessor.value
type: ResourceType = ResourceType.preprocessor
@property
def preprocessor_id(self) -> str:
@ -48,7 +48,7 @@ class Preprocessors(Protocol):
async def get_preprocessor(
self,
preprocessor_id: str,
) -> Optional[Preprocessor]: ...
) -> Preprocessor: ...
@webmethod(route="/preprocessors", method="POST")
async def register_preprocessor(

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import * # noqa: F401 F403

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
config: Dict[str, Any]
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...

View file

@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
)
ScoringFnParams = register_schema(
Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
name="ScoringFnParams",
)
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
@ -135,7 +134,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(

View file

@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(

View file

@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
metric: str
value: Union[int, float]
unit: Optional[str] = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be inlcuded with the response
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
metrics: Optional[List[MetricInResponse]] = None
@json_schema_type
@ -139,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus
StructuredLogPayload = register_schema(
Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
name="StructuredLogPayload",
)
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
@ -157,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload
Event = register_schema(
Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
name="Event",
)
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type

View file

@ -18,6 +18,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
@ -50,16 +59,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str
RAGQueryGeneratorConfig = register_schema(
Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"),
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
name="RAGQueryGeneratorConfig",
)
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type

View file

@ -69,7 +69,7 @@ class ToolGroup(Resource):
@json_schema_type
class ToolInvocationResult(BaseModel):
content: InterleavedContent
content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
@ -88,6 +88,10 @@ class ListToolsResponse(BaseModel):
data: List[Tool]
class ListToolDefsResponse(BaseModel):
data: list[ToolDef]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@ -140,15 +144,15 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:

View file

@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db(
self,
vector_db_id: str,
) -> Optional[VectorDB]: ...
) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(

View file

@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
vector_db_store: VectorDBStore
vector_db_store: VectorDBStore | None = None
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion