wiprouters

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-11-03 17:06:43 +01:00
parent 357be98279
commit 8df9340dd3
No known key found for this signature in database
155 changed files with 61817 additions and 95863 deletions

10 file diffs suppressed because they are too large.

View file

@@ -15,6 +15,7 @@ using multiple validation tools and approaches.
import argparse
import json
import sys
import traceback
from pathlib import Path
from typing import Any
@@ -44,6 +45,8 @@ def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI
return False
except Exception as e:
print(f"{schema_name} validation error: {e}")
print(" Traceback:")
traceback.print_exc()
return False

View file

@@ -4,4 +4,108 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Import routes to trigger router registration
from . import routes # noqa: F401
from .agents_service import AgentsService
from .models import (
Agent,
AgentConfig,
AgentConfigCommon,
AgentConfigOverridablePerTurn,
AgentCreateResponse,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventPayload,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
CreateAgentSessionRequest,
CreateOpenAIResponseRequest,
Document,
InferenceStep,
MemoryRetrievalStep,
ResponseGuardrail,
ResponseGuardrailSpec,
Session,
ShieldCallStep,
Step,
StepCommon,
StepType,
ToolExecutionStep,
Turn,
)
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponsePrompt,
OpenAIResponseText,
)
# Backward compatibility - export Agents as alias for AgentsService
Agents = AgentsService
__all__ = [
"Agents",
"AgentsService",
"Agent",
"AgentConfig",
"AgentConfigCommon",
"AgentConfigOverridablePerTurn",
"AgentCreateResponse",
"AgentSessionCreateResponse",
"AgentStepResponse",
"AgentToolGroup",
"AgentToolGroupWithArgs",
"AgentTurnCreateRequest",
"AgentTurnResumeRequest",
"AgentTurnResponseEvent",
"AgentTurnResponseEventPayload",
"AgentTurnResponseEventType",
"AgentTurnResponseStepCompletePayload",
"AgentTurnResponseStepProgressPayload",
"AgentTurnResponseStepStartPayload",
"AgentTurnResponseStreamChunk",
"AgentTurnResponseTurnAwaitingInputPayload",
"AgentTurnResponseTurnCompletePayload",
"AgentTurnResponseTurnStartPayload",
"Attachment",
"CreateAgentSessionRequest",
"CreateOpenAIResponseRequest",
"Document",
"InferenceStep",
"MemoryRetrievalStep",
"ResponseGuardrail",
"ResponseGuardrailSpec",
"Session",
"ShieldCallStep",
"Step",
"StepCommon",
"StepType",
"ToolExecutionStep",
"Turn",
"ListOpenAIResponseInputItem",
"ListOpenAIResponseObject",
"OpenAIDeleteResponseObject",
"OpenAIResponseInput",
"OpenAIResponseInputTool",
"OpenAIResponseObject",
"OpenAIResponseObjectStream",
"OpenAIResponsePrompt",
"OpenAIResponseText",
]
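
The rewritten __init__.py keeps the public surface stable: everything the old module exported is re-exported from the new submodules, and Agents stays importable as an alias of AgentsService. A minimal sketch of what that buys callers (assuming the package is importable as llama_stack.apis.agents):

from llama_stack.apis.agents import Agents, AgentsService

# The backward-compatibility alias is the same Protocol object, so existing
# imports and type annotations written against Agents keep resolving.
assert Agents is AgentsService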

View file

@@ -1,814 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponsePrompt,
OpenAIResponseText,
)
@json_schema_type
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation.
:param type: The type/identifier of the guardrail.
"""
type: str
# TODO: more fields to be added for guardrail configuration
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context from vector DBs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_store_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_store_ids: str
inserted_context: InterleavedContent
Step = Annotated[
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System.
:param turn_id: Unique identifier for the turn within a session
:param session_id: Unique identifier for the conversation session
:param input_messages: List of messages that initiated this turn
:param steps: Ordered list of processing steps executed during this turn
:param output_message: The model's generated response containing content and metadata
:param output_attachments: (Optional) Files or media attached to the agent's response
:param started_at: Timestamp when the turn began
:param completed_at: (Optional) Timestamp when the turn finished, if completed
"""
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System.
:param session_id: Unique identifier for the conversation session
:param session_name: Human-readable name for the session
:param turns: List of all turns that have occurred in this session
:param started_at: Timestamp when the session was created
"""
session_id: str
session_name: str
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: dict[str, Any]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent.
:param model: The model identifier to use for the agent
:param instructions: The system instructions for the agent
:param name: Optional name for the agent, used in telemetry and identification
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
:param response_format: Optional response format configuration
"""
model: str
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
class Agent(BaseModel):
"""An agent instance with configuration and metadata.
:param agent_id: Unique identifier for the agent
:param agent_config: Configuration settings for the agent
:param created_at: Timestamp when the agent was created
"""
agent_id: str
agent_config: AgentConfig
created_at: datetime
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
"""Payload for step start events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param metadata: (Optional) Additional metadata for the step
"""
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
"""Payload for step completion events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param step_details: Complete details of the executed step
"""
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
"""Payload for step progress events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param delta: Incremental content changes during step execution
"""
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
delta: ContentDelta
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
"""Payload for turn start events in agent turn responses.
:param event_type: Type of event being reported
:param turn_id: Unique identifier for the turn within a session
"""
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
"""Payload for turn completion events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Complete turn data including all steps and results
"""
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
"""Payload for turn awaiting input events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Turn data when waiting for external tool responses
"""
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
AgentTurnResponseEventPayload = Annotated[
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
"""An event in an agent turn response stream.
:param payload: Event-specific payload containing event data
"""
payload: AgentTurnResponseEventPayload
@json_schema_type
class AgentCreateResponse(BaseModel):
"""Response returned when creating a new agent.
:param agent_id: Unique identifier for the created agent
"""
agent_id: str
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
"""Response returned when creating a new agent session.
:param session_id: Unique identifier for the created session
"""
session_id: str
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
"""Request to create a new turn for an agent.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param messages: List of messages to start the turn with
:param documents: (Optional) List of documents to provide to the agent
:param toolgroups: (Optional) List of tool groups to make available for this turn
:param stream: (Optional) Whether to stream the response
:param tool_config: (Optional) Tool configuration to override agent defaults
"""
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and clarify why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
"""Request to resume an agent turn with tool responses.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param turn_id: Unique identifier for the turn within a session
:param tool_responses: List of tool responses to submit to continue the turn
:param stream: (Optional) Whether to stream the response
"""
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""Streamed agent turn completion response.
:param event: Individual event in the agent turn response stream
"""
event: AgentTurnResponseEvent
@json_schema_type
class AgentStepResponse(BaseModel):
"""Response containing details of a specific agent step.
:param step: The complete step data and execution details
"""
step: Step
@runtime_checkable
class Agents(Protocol):
"""Agents
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with; these are used in addition to the agent's configured toolgroups for this request.
:param tool_config: (Optional) The tool configuration to create the turn with; this overrides the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
level=LLAMA_STACK_API_V1ALPHA,
)
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: The Agent object for the given ID.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
# We situate the OpenAI Responses API in the Agents API just like we did things
# for Inference. The Responses API, in its intent, serves the same purpose as
# the Agents API above -- it is essentially a lightweight "agentic loop" with
# integrated tool calling.
#
# Both of these APIs are inherently stateful.
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
"""Get a model response.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
prompt: OpenAIResponsePrompt | None = None,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param prompt: (Optional) Prompt object with ID, version, and variables.
:param previous_response_id: (Optional) If specified, the new response will be a continuation of the previous response. This can be used to easily fork off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
:param model: The model to filter responses by.
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
:returns: A ListOpenAIResponseObject.
"""
...
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: The order to return the input items in. Default is desc.
:returns: A ListOpenAIResponseInputItem.
"""
...
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete a response.
:param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject.
"""
...

View file

@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import ToolConfig, ToolResponse, ToolResponseMessage, UserMessage
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import ExtraBodyField
from .models import (
Agent,
AgentConfig,
AgentCreateResponse,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnResponseStreamChunk,
Document,
ResponseGuardrail,
Session,
Turn,
)
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponsePrompt,
OpenAIResponseText,
)
@runtime_checkable
@trace_protocol
class AgentsService(Protocol):
"""Agents
APIs for creating and interacting with agentic systems."""
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with; these are used in addition to the agent's configured toolgroups for this request.
:param tool_config: (Optional) The tool configuration to create the turn with; this overrides the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
async def get_agents_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
async def get_agents_step(
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
async def delete_agent(
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: The Agent object for the given ID.
"""
...
async def list_agent_sessions(
self,
agent_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
"""Get a model response.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
"""
...
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
prompt: OpenAIResponsePrompt | None = None,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param prompt: (Optional) Prompt object with ID, version, and variables.
:param previous_response_id: (Optional) If specified, the new response will be a continuation of the previous response. This can be used to easily fork off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
:param model: The model to filter responses by.
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
:returns: A ListOpenAIResponseObject.
"""
...
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: The order to return the input items in. Default is desc.
:returns: A ListOpenAIResponseInputItem.
"""
...
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete a response.
:param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject.
"""
...
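
AgentsService is a plain runtime_checkable Protocol, so any object exposing these async methods can back the HTTP routes. A short usage sketch against the signatures above, using the types already imported at the top of this file (the model name, instructions, and whatever wiring produces svc are illustrative assumptions, not part of this commit):

async def demo(svc: AgentsService) -> None:
    # Create an agent, open a session, then run one non-streaming turn.
    created = await svc.create_agent(
        agent_config=AgentConfig(model="example-model", instructions="Be helpful.")
    )
    session = await svc.create_agent_session(agent_id=created.agent_id, session_name="demo")
    turn = await svc.create_agent_turn(
        agent_id=created.agent_id,
        session_id=session.session_id,
        messages=[UserMessage(content="Hello")],
        stream=False,  # with stream=True the same call yields AgentTurnResponseStreamChunk events
    )
    # With stream=False the result is a Turn, so the final message is available directly.
    print(turn.output_message)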

View file

@@ -0,0 +1,409 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema
from .openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponsePrompt,
OpenAIResponseText,
)
@json_schema_type
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation."""
type: str = Field(description="The type/identifier of the guardrail.")
# TODO: more fields to be added for guardrail configuration
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
"""An attachment to an agent turn."""
content: InterleavedContent | URL = Field(description="The content of the attachment.")
mime_type: str = Field(description="The MIME type of the attachment.")
class Document(BaseModel):
"""A document to be used by an agent."""
content: InterleavedContent | URL = Field(description="The content of the document.")
mime_type: str = Field(description="The MIME type of the document.")
class StepCommon(BaseModel):
"""A common step in an agent turn."""
turn_id: str = Field(description="The ID of the turn.")
step_id: str = Field(description="The ID of the step.")
started_at: datetime | None = Field(default=None, description="The time the step started.")
completed_at: datetime | None = Field(default=None, description="The time the step completed.")
class StepType(StrEnum):
"""Type of the step in an agent turn."""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn."""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = Field(default=StepType.inference)
model_response: CompletionMessage = Field(description="The response from the LLM.")
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn."""
step_type: Literal[StepType.tool_execution] = Field(default=StepType.tool_execution)
tool_calls: list[ToolCall] = Field(description="The tool calls to execute.")
tool_responses: list[ToolResponse] = Field(description="The tool responses from the tool calls.")
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn."""
step_type: Literal[StepType.shield_call] = Field(default=StepType.shield_call)
violation: SafetyViolation | None = Field(default=None, description="The violation from the shield call.")
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn."""
step_type: Literal[StepType.memory_retrieval] = Field(default=StepType.memory_retrieval)
# TODO: should this be List[str]?
vector_store_ids: str = Field(description="The IDs of the vector databases to retrieve context from.")
inserted_context: InterleavedContent = Field(description="The context retrieved from the vector databases.")
Step = Annotated[
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System."""
turn_id: str = Field(description="Unique identifier for the turn within a session")
session_id: str = Field(description="Unique identifier for the conversation session")
input_messages: list[UserMessage | ToolResponseMessage] = Field(
description="List of messages that initiated this turn"
)
steps: list[Step] = Field(description="Ordered list of processing steps executed during this turn")
output_message: CompletionMessage = Field(
description="The model's generated response containing content and metadata"
)
output_attachments: list[Attachment] | None = Field(
default_factory=lambda: [], description="Files or media attached to the agent's response"
)
started_at: datetime = Field(description="Timestamp when the turn began")
completed_at: datetime | None = Field(default=None, description="Timestamp when the turn finished, if completed")
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System."""
session_id: str = Field(description="Unique identifier for the conversation session")
session_name: str = Field(description="Human-readable name for the session")
turns: list[Turn] = Field(description="List of all turns that have occurred in this session")
started_at: datetime = Field(description="Timestamp when the session was created")
class AgentToolGroupWithArgs(BaseModel):
name: str = Field()
args: dict[str, Any] = Field()
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent."""
model: str = Field(description="The model identifier to use for the agent")
instructions: str = Field(description="The system instructions for the agent")
name: str | None = Field(
default=None, description="Optional name for the agent, used in telemetry and identification"
)
enable_session_persistence: bool | None = Field(
default=False, description="Optional flag indicating whether session data has to be persisted"
)
response_format: ResponseFormat | None = Field(default=None, description="Optional response format configuration")
@json_schema_type
class Agent(BaseModel):
"""An agent instance with configuration and metadata."""
agent_id: str = Field(description="Unique identifier for the agent")
agent_config: AgentConfig = Field(description="Configuration settings for the agent")
created_at: datetime = Field(description="Timestamp when the agent was created")
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = Field(default=None)
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
"""Payload for step start events in agent turn responses."""
event_type: Literal[AgentTurnResponseEventType.step_start] = Field(
default=AgentTurnResponseEventType.step_start, description="Type of event being reported"
)
step_type: StepType = Field(description="Type of step being executed")
step_id: str = Field(description="Unique identifier for the step within a turn")
metadata: dict[str, Any] | None = Field(default_factory=lambda: {}, description="Additional metadata for the step")
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
"""Payload for step completion events in agent turn responses."""
event_type: Literal[AgentTurnResponseEventType.step_complete] = Field(
default=AgentTurnResponseEventType.step_complete, description="Type of event being reported"
)
step_type: StepType = Field(description="Type of step being executed")
step_id: str = Field(description="Unique identifier for the step within a turn")
step_details: Step = Field(description="Complete details of the executed step")
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
"""Payload for step progress events in agent turn responses."""
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress] = Field(
default=AgentTurnResponseEventType.step_progress, description="Type of event being reported"
)
step_type: StepType = Field(description="Type of step being executed")
step_id: str = Field(description="Unique identifier for the step within a turn")
delta: ContentDelta = Field(description="Incremental content changes during step execution")
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
"""Payload for turn start events in agent turn responses."""
event_type: Literal[AgentTurnResponseEventType.turn_start] = Field(
default=AgentTurnResponseEventType.turn_start, description="Type of event being reported"
)
turn_id: str = Field(description="Unique identifier for the turn within a session")
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
"""Payload for turn completion events in agent turn responses."""
event_type: Literal[AgentTurnResponseEventType.turn_complete] = Field(
default=AgentTurnResponseEventType.turn_complete, description="Type of event being reported"
)
turn: Turn = Field(description="Complete turn data including all steps and results")
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
"""Payload for turn awaiting input events in agent turn responses."""
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = Field(
default=AgentTurnResponseEventType.turn_awaiting_input, description="Type of event being reported"
)
turn: Turn = Field(description="Turn data when waiting for external tool responses")
AgentTurnResponseEventPayload = Annotated[
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
"""An event in an agent turn response stream."""
payload: AgentTurnResponseEventPayload = Field(description="Event-specific payload containing event data")
@json_schema_type
class AgentCreateResponse(BaseModel):
"""Response returned when creating a new agent."""
agent_id: str = Field(description="Unique identifier for the created agent")
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
"""Response returned when creating a new agent session."""
session_id: str = Field(description="Unique identifier for the created session")
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
"""Request to create a new turn for an agent."""
agent_id: str = Field(description="Unique identifier for the agent")
session_id: str = Field(description="Unique identifier for the conversation session")
# TODO: figure out how we can simplify this and clarify why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage] = Field(description="List of messages to start the turn with")
documents: list[Document] | None = Field(default=None, description="List of documents to provide to the agent")
toolgroups: list[AgentToolGroup] | None = Field(
default_factory=lambda: [], description="List of tool groups to make available for this turn"
)
stream: bool | None = Field(default=False, description="Whether to stream the response")
tool_config: ToolConfig | None = Field(default=None, description="Tool configuration to override agent defaults")
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
"""Request to resume an agent turn with tool responses."""
agent_id: str = Field(description="Unique identifier for the agent")
session_id: str = Field(description="Unique identifier for the conversation session")
turn_id: str = Field(description="Unique identifier for the turn within a session")
tool_responses: list[ToolResponse] = Field(description="List of tool responses to submit to continue the turn")
stream: bool | None = Field(default=False, description="Whether to stream the response")
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""Streamed agent turn completion response."""
event: AgentTurnResponseEvent = Field(description="Individual event in the agent turn response stream")
@json_schema_type
class AgentStepResponse(BaseModel):
"""Response containing details of a specific agent step."""
step: Step = Field(description="The complete step data and execution details")
@json_schema_type
class CreateAgentSessionRequest(BaseModel):
"""Request to create a new session for an agent."""
agent_id: str = Field(..., description="The ID of the agent to create the session for")
session_name: str = Field(..., description="The name of the session to create")
@json_schema_type
class CreateOpenAIResponseRequest(BaseModel):
"""Request to create a model response."""
input: str | list[OpenAIResponseInput] = Field(..., description="Input message(s) to create the response")
model: str = Field(..., description="The underlying LLM used for completions")
prompt: OpenAIResponsePrompt | None = Field(None, description="Prompt object with ID, version, and variables")
instructions: str | None = Field(None, description="System instructions")
previous_response_id: str | None = Field(
None, description="If specified, the new response will be a continuation of the previous response"
)
conversation: str | None = Field(
None, description="The ID of a conversation to add the response to. Must begin with 'conv_'"
)
store: bool = Field(True, description="Whether to store the response")
stream: bool = Field(False, description="Whether to stream the response")
temperature: float | None = Field(None, description="Sampling temperature")
text: OpenAIResponseText | None = Field(None, description="Text generation parameters")
tools: list[OpenAIResponseInputTool] | None = Field(None, description="Tools to make available")
include: list[str] | None = Field(None, description="Additional fields to include in the response")
max_infer_iters: int = Field(10, description="Maximum number of inference iterations (extension to the OpenAI API)")
guardrails: list[ResponseGuardrail] | None = Field(
None, description="List of guardrails to apply during response generation"
)
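
One behaviour worth calling out in AgentConfigCommon.model_post_init above: when no tool_config is supplied, the deprecated tool_choice / tool_prompt_format fields are folded into a freshly built tool_config, and a conflicting combination raises. A small sketch of that contract (assuming ToolChoice from llama_stack.apis.inference exposes auto and required members):

# Deprecated top-level field is folded into tool_config when tool_config is absent.
cfg = AgentConfig(model="example-model", instructions="Be helpful.", tool_choice=ToolChoice.auto)
assert cfg.tool_config is not None and cfg.tool_config.tool_choice == ToolChoice.auto

# Supplying both, with disagreeing values, is rejected.
try:
    AgentConfig(
        model="example-model",
        instructions="Be helpful.",
        tool_choice=ToolChoice.required,
        tool_config=ToolConfig(tool_choice=ToolChoice.auto),
    )
except ValueError:
    pass  # deprecated field conflicting with tool_config raises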

File diff suppressed because it is too large.

View file

@@ -0,0 +1,452 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.common.responses import Order
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .agents_service import AgentsService
from .models import (
Agent,
AgentConfig,
AgentCreateResponse,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
CreateAgentSessionRequest,
CreateOpenAIResponseRequest,
Session,
Turn,
)
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseObject,
)
def get_agents_service(request: Request) -> AgentsService:
"""Dependency to get the agents service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.agents not in impls:
raise ValueError("Agents API implementation not found")
return impls[Api.agents]
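# Usage sketch (an assumption for illustration, not wiring added by this commit):
# the application factory is expected to stash provider implementations on
# app.state before serving, e.g.
#
#     app.state.impls = {Api.agents: my_agents_impl}  # my_agents_impl satisfies AgentsService
#
# FastAPI then resolves every `svc: AgentsService = Depends(get_agents_service)`
# parameter below from that mapping, which also makes it easy for tests to swap
# in a fake service by replacing the Api.agents entry.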
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Agents"],
responses=standard_responses,
)
router_v1alpha = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
tags=["Agents"],
responses=standard_responses,
)
@router.post(
"/agents",
response_model=AgentCreateResponse,
summary="Create an agent.",
description="Create an agent with the given configuration.",
deprecated=True,
)
@router_v1alpha.post(
"/agents",
response_model=AgentCreateResponse,
summary="Create an agent.",
description="Create an agent with the given configuration.",
)
async def create_agent(
agent_config: AgentConfig = Body(...),
svc: AgentsService = Depends(get_agents_service),
) -> AgentCreateResponse:
"""Create an agent with the given configuration."""
return await svc.create_agent(agent_config=agent_config)
@router.post(
"/agents/{agent_id}/session/{session_id}/turn",
summary="Create a new turn for an agent.",
description="Create a new turn for an agent.",
deprecated=True,
)
@router_v1alpha.post(
"/agents/{{agent_id}}/session/{{session_id}}/turn",
summary="Create a new turn for an agent.",
description="Create a new turn for an agent.",
)
async def create_agent_turn(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to create the turn for.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to create the turn for.")],
body: AgentTurnCreateRequest = Body(...),
svc: AgentsService = Depends(get_agents_service),
):
"""Create a new turn for an agent."""
return await svc.create_agent_turn(
agent_id=agent_id,
session_id=session_id,
messages=body.messages,
stream=body.stream,
documents=body.documents,
toolgroups=body.toolgroups,
tool_config=body.tool_config,
)
@router.post(
"/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
summary="Resume an agent turn.",
description="Resume an agent turn with executed tool call responses.",
deprecated=True,
)
@router_v1alpha.post(
"/agents/{{agent_id}}/session/{{session_id}}/turn/{{turn_id}}/resume",
summary="Resume an agent turn.",
description="Resume an agent turn with executed tool call responses.",
)
async def resume_agent_turn(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to resume.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to resume.")],
turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to resume.")],
body: AgentTurnResumeRequest = Body(...),
svc: AgentsService = Depends(get_agents_service),
):
"""Resume an agent turn with executed tool call responses."""
return await svc.resume_agent_turn(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=body.tool_responses,
stream=body.stream,
)
@router.get(
"/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
response_model=Turn,
summary="Retrieve an agent turn.",
description="Retrieve an agent turn by its ID.",
deprecated=True,
)
@router_v1alpha.get(
"/agents/{{agent_id}}/session/{{session_id}}/turn/{{turn_id}}",
response_model=Turn,
summary="Retrieve an agent turn.",
description="Retrieve an agent turn by its ID.",
)
async def get_agents_turn(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the turn for.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get the turn for.")],
turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to get.")],
svc: AgentsService = Depends(get_agents_service),
) -> Turn:
"""Retrieve an agent turn by its ID."""
return await svc.get_agents_turn(agent_id=agent_id, session_id=session_id, turn_id=turn_id)
@router.get(
"/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
response_model=AgentStepResponse,
summary="Retrieve an agent step.",
description="Retrieve an agent step by its ID.",
deprecated=True,
)
@router_v1alpha.get(
"/agents/{{agent_id}}/session/{{session_id}}/turn/{{turn_id}}/step/{{step_id}}",
response_model=AgentStepResponse,
summary="Retrieve an agent step.",
description="Retrieve an agent step by its ID.",
)
async def get_agents_step(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the step for.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get the step for.")],
turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to get the step for.")],
step_id: Annotated[str, FastAPIPath(..., description="The ID of the step to get.")],
svc: AgentsService = Depends(get_agents_service),
) -> AgentStepResponse:
"""Retrieve an agent step by its ID."""
return await svc.get_agents_step(agent_id=agent_id, session_id=session_id, turn_id=turn_id, step_id=step_id)
@router.post(
"/agents/{agent_id}/session",
response_model=AgentSessionCreateResponse,
summary="Create a new session for an agent.",
description="Create a new session for an agent.",
deprecated=True,
)
@router_v1alpha.post(
"/agents/{{agent_id}}/session",
response_model=AgentSessionCreateResponse,
summary="Create a new session for an agent.",
description="Create a new session for an agent.",
)
async def create_agent_session(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to create the session for.")],
body: CreateAgentSessionRequest = Body(...),
svc: AgentsService = Depends(get_agents_service),
) -> AgentSessionCreateResponse:
"""Create a new session for an agent."""
return await svc.create_agent_session(agent_id=agent_id, session_name=body.session_name)
@router.get(
"/agents/{agent_id}/session/{session_id}",
response_model=Session,
summary="Retrieve an agent session.",
description="Retrieve an agent session by its ID.",
deprecated=True,
)
@router_v1alpha.get(
"/agents/{{agent_id}}/session/{{session_id}}",
response_model=Session,
summary="Retrieve an agent session.",
description="Retrieve an agent session by its ID.",
)
async def get_agents_session(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the session for.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get.")],
turn_ids: list[str] | None = Query(None, description="List of turn IDs to filter the session by."),
svc: AgentsService = Depends(get_agents_service),
) -> Session:
"""Retrieve an agent session by its ID."""
return await svc.get_agents_session(session_id=session_id, agent_id=agent_id, turn_ids=turn_ids)
@router.delete(
"/agents/{agent_id}/session/{session_id}",
response_model=None,
status_code=204,
summary="Delete an agent session.",
description="Delete an agent session by its ID.",
deprecated=True,
)
@router_v1alpha.delete(
"/agents/{{agent_id}}/session/{{session_id}}",
response_model=None,
status_code=204,
summary="Delete an agent session.",
description="Delete an agent session by its ID.",
)
async def delete_agents_session(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to delete the session for.")],
session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to delete.")],
svc: AgentsService = Depends(get_agents_service),
) -> None:
"""Delete an agent session by its ID and its associated turns."""
await svc.delete_agents_session(session_id=session_id, agent_id=agent_id)
@router.delete(
"/agents/{agent_id}",
response_model=None,
status_code=204,
summary="Delete an agent.",
description="Delete an agent by its ID.",
deprecated=True,
)
@router_v1alpha.delete(
"/agents/{{agent_id}}",
response_model=None,
status_code=204,
summary="Delete an agent.",
description="Delete an agent by its ID.",
)
async def delete_agent(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to delete.")],
svc: AgentsService = Depends(get_agents_service),
) -> None:
"""Delete an agent by its ID and its associated sessions and turns."""
await svc.delete_agent(agent_id=agent_id)
@router.get(
"/agents",
summary="List all agents.",
description="List all agents.",
deprecated=True,
)
@router_v1alpha.get(
"/agents",
summary="List all agents.",
description="List all agents.",
)
async def list_agents(
start_index: int | None = Query(None, description="The index to start the pagination from."),
limit: int | None = Query(None, description="The number of agents to return."),
svc: AgentsService = Depends(get_agents_service),
):
"""List all agents."""
return await svc.list_agents(start_index=start_index, limit=limit)
@router.get(
"/agents/{agent_id}",
response_model=Agent,
summary="Describe an agent.",
description="Describe an agent by its ID.",
deprecated=True,
)
@router_v1alpha.get(
"/agents/{{agent_id}}",
response_model=Agent,
summary="Describe an agent.",
description="Describe an agent by its ID.",
)
async def get_agent(
agent_id: Annotated[str, FastAPIPath(..., description="ID of the agent.")],
svc: AgentsService = Depends(get_agents_service),
) -> Agent:
"""Describe an agent by its ID."""
return await svc.get_agent(agent_id=agent_id)
@router.get(
"/agents/{agent_id}/sessions",
summary="List all sessions of an agent.",
description="List all session(s) of a given agent.",
deprecated=True,
)
@router_v1alpha.get(
"/agents/{{agent_id}}/sessions",
summary="List all sessions of an agent.",
description="List all session(s) of a given agent.",
)
async def list_agent_sessions(
agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to list sessions for.")],
start_index: int | None = Query(None, description="The index to start the pagination from."),
limit: int | None = Query(None, description="The number of sessions to return."),
svc: AgentsService = Depends(get_agents_service),
):
"""List all session(s) of a given agent."""
return await svc.list_agent_sessions(agent_id=agent_id, start_index=start_index, limit=limit)
# OpenAI Responses API endpoints
@router.get(
"/responses/{response_id}",
response_model=OpenAIResponseObject,
summary="Get a model response.",
description="Get a model response.",
)
async def get_openai_response(
response_id: Annotated[str, FastAPIPath(..., description="The ID of the OpenAI response to retrieve.")],
svc: AgentsService = Depends(get_agents_service),
) -> OpenAIResponseObject:
"""Get a model response."""
return await svc.get_openai_response(response_id=response_id)
@router.post(
"/responses",
summary="Create a model response.",
description="Create a model response.",
)
async def create_openai_response(
body: CreateOpenAIResponseRequest = Body(...),
svc: AgentsService = Depends(get_agents_service),
):
"""Create a model response."""
return await svc.create_openai_response(
input=body.input,
model=body.model,
prompt=body.prompt,
instructions=body.instructions,
previous_response_id=body.previous_response_id,
conversation=body.conversation,
store=body.store,
stream=body.stream,
temperature=body.temperature,
text=body.text,
tools=body.tools,
include=body.include,
max_infer_iters=body.max_infer_iters,
guardrails=body.guardrails,
)
@router.get(
"/responses",
response_model=ListOpenAIResponseObject,
summary="List all responses.",
description="List all responses.",
)
async def list_openai_responses(
after: str | None = Query(None, description="The ID of the last response to return."),
limit: int | None = Query(50, description="The number of responses to return."),
model: str | None = Query(None, description="The model to filter responses by."),
order: Order | None = Query(
Order.desc, description="The order to sort responses by when sorted by created_at ('asc' or 'desc')."
),
svc: AgentsService = Depends(get_agents_service),
) -> ListOpenAIResponseObject:
"""List all responses."""
return await svc.list_openai_responses(after=after, limit=limit, model=model, order=order)
@router.get(
"/responses/{response_id}/input_items",
response_model=ListOpenAIResponseInputItem,
summary="List input items.",
description="List input items.",
)
async def list_openai_response_input_items(
response_id: Annotated[str, FastAPIPath(..., description="The ID of the response to retrieve input items for.")],
after: str | None = Query(None, description="An item ID to list items after, used for pagination."),
before: str | None = Query(None, description="An item ID to list items before, used for pagination."),
include: list[str] | None = Query(None, description="Additional fields to include in the response."),
limit: int | None = Query(
20,
description="A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.",
ge=1,
le=100,
),
order: Order | None = Query(Order.desc, description="The order to return the input items in. Default is desc."),
svc: AgentsService = Depends(get_agents_service),
) -> ListOpenAIResponseInputItem:
"""List input items."""
return await svc.list_openai_response_input_items(
response_id=response_id, after=after, before=before, include=include, limit=limit, order=order
)
@router.delete(
"/responses/{response_id}",
response_model=OpenAIDeleteResponseObject,
summary="Delete a response.",
description="Delete a response.",
)
async def delete_openai_response(
response_id: Annotated[str, FastAPIPath(..., description="The ID of the OpenAI response to delete.")],
svc: AgentsService = Depends(get_agents_service),
) -> OpenAIDeleteResponseObject:
"""Delete a response."""
return await svc.delete_openai_response(response_id=response_id)
# For backward compatibility with the router registry system
def create_agents_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Agents API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1alpha)
return main_router
# Register the router factory
register_router(Api.agents, create_agents_router)
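A minimal sketch of how these routes could be mounted and exercised in a test application. The in-memory implementation, its return value, and the AgentConfig fields shown are illustrative assumptions; only the app.state.impls lookup and the legacy router factory come from the code above.

from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack.apis.agents.routes import create_agents_router
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA


class InMemoryAgents:
    """Hypothetical stand-in for a real AgentsService implementation."""

    async def create_agent(self, agent_config):
        # Illustrative only; a real implementation would persist the agent config.
        return {"agent_id": "agent-123"}


app = FastAPI()
# get_agents_service() looks up the implementation in app.state.impls by Api.agents.
app.state.impls = {Api.agents: InMemoryAgents()}
app.include_router(create_agents_router(None))  # impl_getter is unused by the legacy factory

client = TestClient(app)
resp = client.post(
    f"/{LLAMA_STACK_API_V1ALPHA}/agents",
    # Indicative AgentConfig fields; adjust to the fields your stack version requires.
    json={"model": "llama3.2-3b-instruct", "instructions": "You are a helpful assistant."},
)
print(resp.status_code, resp.json())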

View file

@ -4,6 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .batches import Batches, BatchObject, ListBatchesResponse
try:
from openai.types import Batch as BatchObject
except ImportError:
BatchObject = None # type: ignore[assignment,misc]
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] # Import routes to trigger router registration
from . import routes # noqa: F401
from .batches_service import BatchService
from .models import CreateBatchRequest, ListBatchesResponse
# Backward compatibility - export Batches as alias for BatchService
Batches = BatchService
__all__ = ["Batches", "BatchService", "BatchObject", "ListBatchesResponse", "CreateBatchRequest"]

View file

@ -1,96 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
object: Literal["list"] = "list"
data: list[BatchObject] = Field(..., description="List of batch objects")
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
has_more: bool = Field(default=False, description="Whether there are more batches available")
@runtime_checkable
class Batches(Protocol):
"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
:param input_file_id: The ID of an uploaded file containing requests for the batch.
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
:param batch_id: The ID of the batch to retrieve.
:returns: The batch object.
"""
...
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
:param batch_id: The ID of the batch to cancel.
:returns: The updated batch object.
"""
...
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""List all batches for the current user.
:param after: A cursor for pagination; returns batches after this batch ID.
:param limit: Number of batches to return (default 20, max 100).
:returns: A list of batch objects.
"""
...

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
from .models import ListBatchesResponse
@runtime_checkable
class BatchService(Protocol):
"""The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests."""
...
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch."""
...
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress."""
...
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""List all batches for the current user."""
...

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
@json_schema_type
class CreateBatchRequest(BaseModel):
"""Request model for creating a batch."""
input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.")
endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.")
completion_window: Literal["24h"] = Field(
..., description="The time window within which the batch should be processed."
)
metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.")
idempotency_key: str | None = Field(
default=None, description="Optional idempotency key. When provided, enables idempotent behavior."
)
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
object: Literal["list"] = Field(default="list", description="The object type, which is always 'list'.")
data: list[BatchObject] = Field(..., description="List of batch objects.")
first_id: str | None = Field(default=None, description="ID of the first batch in the list.")
last_id: str | None = Field(default=None, description="ID of the last batch in the list.")
has_more: bool = Field(default=False, description="Whether there are more batches available.")

View file

@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .batches_service import BatchService
from .models import CreateBatchRequest, ListBatchesResponse
def get_batch_service(request: Request) -> BatchService:
"""Dependency to get the batch service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.batches not in impls:
raise ValueError("Batches API implementation not found")
return impls[Api.batches]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Batches"],
responses=standard_responses,
)
@router.post(
"/batches",
response_model=BatchObject,
summary="Create a new batch for processing multiple API requests.",
description="Create a new batch for processing multiple API requests.",
)
async def create_batch(
request: CreateBatchRequest = Body(...),
svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
"""Create a new batch."""
return await svc.create_batch(
input_file_id=request.input_file_id,
endpoint=request.endpoint,
completion_window=request.completion_window,
metadata=request.metadata,
idempotency_key=request.idempotency_key,
)
@router.get(
"/batches/{batch_id}",
response_model=BatchObject,
summary="Retrieve information about a specific batch.",
description="Retrieve information about a specific batch.",
)
async def retrieve_batch(
batch_id: Annotated[str, FastAPIPath(..., description="The ID of the batch to retrieve.")],
svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
"""Retrieve batch information."""
return await svc.retrieve_batch(batch_id)
@router.post(
"/batches/{batch_id}/cancel",
response_model=BatchObject,
summary="Cancel a batch that is in progress.",
description="Cancel a batch that is in progress.",
)
async def cancel_batch(
batch_id: Annotated[str, FastAPIPath(..., description="The ID of the batch to cancel.")],
svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
"""Cancel a batch."""
return await svc.cancel_batch(batch_id)
@router.get(
"/batches",
response_model=ListBatchesResponse,
summary="List all batches for the current user.",
description="List all batches for the current user.",
)
async def list_batches(
after: str | None = Query(None, description="A cursor for pagination; returns batches after this batch ID."),
limit: int = Query(20, description="Number of batches to return (default 20, max 100).", ge=1, le=100),
svc: BatchService = Depends(get_batch_service),
) -> ListBatchesResponse:
"""List all batches."""
return await svc.list_batches(after=after, limit=limit)
# For backward compatibility with the router registry system
def create_batches_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Batches API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.batches, create_batches_router)
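A hedged sketch of calling these batch endpoints from a client with httpx. The host, port, and "/v1" prefix are assumptions about a locally running server; the payload fields mirror CreateBatchRequest above.

import httpx

BASE = "http://localhost:8321/v1"  # assumed host/port and API prefix

payload = {
    "input_file_id": "file-abc123",      # ID of a previously uploaded file of requests
    "endpoint": "/v1/chat/completions",  # endpoint applied to every request in the batch
    "completion_window": "24h",          # the only window CreateBatchRequest accepts
    "metadata": {"purpose": "nightly-eval"},
    "idempotency_key": "nightly-eval-2025-11-03",  # enables idempotent creation
}

with httpx.Client() as client:
    resp = client.post(f"{BASE}/batches", json=payload)
    resp.raise_for_status()
    batch = resp.json()

    # Check on the batch and page through recent batches.
    status = client.get(f"{BASE}/batches/{batch['id']}").json()
    recent = client.get(f"{BASE}/batches", params={"limit": 5}).json()
    print(status["status"], recent["has_more"])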

View file

@ -4,4 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .benchmarks import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .benchmarks_service import BenchmarksService
from .models import (
Benchmark,
BenchmarkInput,
CommonBenchmarkFields,
ListBenchmarksResponse,
RegisterBenchmarkRequest,
)
# Backward compatibility - export Benchmarks as alias for BenchmarksService
Benchmarks = BenchmarksService
__all__ = [
"Benchmarks",
"BenchmarksService",
"Benchmark",
"BenchmarkInput",
"CommonBenchmarkFields",
"ListBenchmarksResponse",
"RegisterBenchmarkRequest",
]

View file

@ -1,104 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
dataset_id: str
scoring_functions: list[str]
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
"""A benchmark resource for evaluating model performance.
:param dataset_id: Identifier of the dataset to use for the benchmark evaluation
:param scoring_functions: List of scoring function identifiers to apply during evaluation
:param metadata: Metadata for this evaluation task
:param type: The resource type, always benchmark
"""
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property
def benchmark_id(self) -> str:
return self.identifier
@property
def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
benchmark_id: str
provider_id: str | None = None
provider_benchmark_id: str | None = None
class ListBenchmarksResponse(BaseModel):
data: list[Benchmark]
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.
:param benchmark_id: The ID of the benchmark to unregister.
"""
...

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import Benchmark, ListBenchmarksResponse
@runtime_checkable
@trace_protocol
class BenchmarksService(Protocol):
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks."""
...
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark:
"""Get a benchmark by its ID."""
...
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Register a benchmark."""
...
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark."""
...

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type
class CommonBenchmarkFields(BaseModel):
dataset_id: str = Field(..., description="The ID of the dataset to use for the benchmark")
scoring_functions: list[str] = Field(..., description="The scoring functions to use for the benchmark")
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
"""A benchmark resource for evaluating model performance."""
type: Literal[ResourceType.benchmark] = Field(
default=ResourceType.benchmark, description="The resource type, always benchmark"
)
class ListBenchmarksResponse(BaseModel):
"""Response model for listing benchmarks."""
data: list[Benchmark] = Field(..., description="List of benchmark resources")
@json_schema_type
class RegisterBenchmarkRequest(BaseModel):
"""Request model for registering a benchmark."""
benchmark_id: str = Field(..., description="The ID of the benchmark to register")
dataset_id: str = Field(..., description="The ID of the dataset to use for the benchmark")
scoring_functions: list[str] = Field(..., description="The scoring functions to use for the benchmark")
provider_benchmark_id: str | None = Field(
default=None, description="The ID of the provider benchmark to use for the benchmark"
)
provider_id: str | None = Field(default=None, description="The ID of the provider to use for the benchmark")
metadata: dict[str, Any] | None = Field(default=None, description="The metadata to use for the benchmark")
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
benchmark_id: str = Field(..., description="The ID of the benchmark to use for the benchmark")
provider_id: str | None = Field(default=None, description="The ID of the provider to use for the benchmark")
provider_benchmark_id: str | None = Field(
default=None, description="The ID of the provider benchmark to use for the benchmark"
)

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .benchmarks_service import BenchmarksService
from .models import Benchmark, ListBenchmarksResponse, RegisterBenchmarkRequest
def get_benchmarks_service(request: Request) -> BenchmarksService:
"""Dependency to get the benchmarks service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.benchmarks not in impls:
raise ValueError("Benchmarks API implementation not found")
return impls[Api.benchmarks]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Benchmarks"],
responses=standard_responses,
)
router_v1alpha = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
tags=["Benchmarks"],
responses=standard_responses,
)
@router.get(
"/eval/benchmarks",
response_model=ListBenchmarksResponse,
summary="List all benchmarks",
description="List all benchmarks",
deprecated=True,
)
@router_v1alpha.get(
"/eval/benchmarks",
response_model=ListBenchmarksResponse,
summary="List all benchmarks",
description="List all benchmarks",
)
async def list_benchmarks(svc: BenchmarksService = Depends(get_benchmarks_service)) -> ListBenchmarksResponse:
"""List all benchmarks."""
return await svc.list_benchmarks()
@router.get(
"/eval/benchmarks/{benchmark_id}",
response_model=Benchmark,
summary="Get a benchmark by its ID",
description="Get a benchmark by its ID",
deprecated=True,
)
@router_v1alpha.get(
"/eval/benchmarks/{{benchmark_id}}",
response_model=Benchmark,
summary="Get a benchmark by its ID",
description="Get a benchmark by its ID",
)
async def get_benchmark(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to get")],
svc: BenchmarksService = Depends(get_benchmarks_service),
) -> Benchmark:
"""Get a benchmark by its ID."""
return await svc.get_benchmark(benchmark_id=benchmark_id)
@router.post(
"/eval/benchmarks",
response_model=None,
status_code=204,
summary="Register a benchmark",
description="Register a benchmark",
deprecated=True,
)
@router_v1alpha.post(
"/eval/benchmarks",
response_model=None,
status_code=204,
summary="Register a benchmark",
description="Register a benchmark",
)
async def register_benchmark(
body: RegisterBenchmarkRequest = Body(...),
svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
"""Register a benchmark."""
return await svc.register_benchmark(
benchmark_id=body.benchmark_id,
dataset_id=body.dataset_id,
scoring_functions=body.scoring_functions,
provider_benchmark_id=body.provider_benchmark_id,
provider_id=body.provider_id,
metadata=body.metadata,
)
@router.delete(
"/eval/benchmarks/{benchmark_id}",
response_model=None,
status_code=204,
summary="Unregister a benchmark",
description="Unregister a benchmark",
deprecated=True,
)
@router_v1alpha.delete(
"/eval/benchmarks/{{benchmark_id}}",
response_model=None,
status_code=204,
summary="Unregister a benchmark",
description="Unregister a benchmark",
)
async def unregister_benchmark(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to unregister")],
svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
"""Unregister a benchmark."""
await svc.unregister_benchmark(benchmark_id=benchmark_id)
# For backward compatibility with the router registry system
def create_benchmarks_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Benchmarks API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1alpha)
return main_router
# Register the router factory
register_router(Api.benchmarks, create_benchmarks_router)
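A sketch of registering and inspecting a benchmark through the routes above, assuming a local server; the base URL, the "v1alpha" prefix value, and the scoring function identifier are placeholders.

import httpx

BASE = "http://localhost:8321/v1alpha"  # assumed; mirrors LLAMA_STACK_API_V1ALPHA

with httpx.Client() as client:
    # Body mirrors RegisterBenchmarkRequest; the endpoint returns 204 on success.
    resp = client.post(
        f"{BASE}/eval/benchmarks",
        json={
            "benchmark_id": "mmlu-mini",
            "dataset_id": "mmlu-mini-dataset",
            "scoring_functions": ["example::exact_match"],  # placeholder scoring function id
            "metadata": {"owner": "eval-team"},
        },
    )
    resp.raise_for_status()

    all_benchmarks = client.get(f"{BASE}/eval/benchmarks").json()["data"]
    one = client.get(f"{BASE}/eval/benchmarks/mmlu-mini").json()
    print(len(all_benchmarks), one["identifier"])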

View file

@ -15,21 +15,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class URL(BaseModel):
-    """A URL reference to external content.
-
-    :param uri: The URL string pointing to the resource
-    """
    """A URL reference to external content."""

    uri: str


class _URLOrData(BaseModel):
-    """
-    A URL or a base64 encoded string
-
-    :param url: A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits.
-    :param data: base64 encoded image data as string
-    """
    """A URL or a base64 encoded string."""

    url: URL | None = None
    # data is a base64 encoded string, hint with contentEncoding=base64
@ -45,11 +37,7 @@ class _URLOrData(BaseModel):
@json_schema_type
class ImageContentItem(BaseModel):
-    """A image content item
-
-    :param type: Discriminator type of the content item. Always "image"
-    :param image: Image as a base64 encoded string or an URL
-    """
    """A image content item."""

    type: Literal["image"] = "image"
    image: _URLOrData
@ -57,11 +45,7 @@ class ImageContentItem(BaseModel):
@json_schema_type
class TextContentItem(BaseModel):
-    """A text content item
-
-    :param type: Discriminator type of the content item. Always "text"
-    :param text: Text content
-    """
    """A text content item."""

    type: Literal["text"] = "text"
    text: str
@ -81,11 +65,7 @@ register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
class TextDelta(BaseModel):
-    """A text content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "text"
-    :param text: The incremental text content
-    """
    """A text content delta for streaming responses."""

    type: Literal["text"] = "text"
    text: str
@ -93,23 +73,14 @@ class TextDelta(BaseModel):
@json_schema_type
class ImageDelta(BaseModel):
-    """An image content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "image"
-    :param image: The incremental image data as bytes
-    """
    """An image content delta for streaming responses."""

    type: Literal["image"] = "image"
    image: bytes


class ToolCallParseStatus(Enum):
-    """Status of tool call parsing during streaming.
-
-    :cvar started: Tool call parsing has begun
-    :cvar in_progress: Tool call parsing is ongoing
-    :cvar failed: Tool call parsing failed
-    :cvar succeeded: Tool call parsing completed successfully
-    """
    """Status of tool call parsing during streaming."""

    started = "started"
    in_progress = "in_progress"
@ -119,12 +90,7 @@ class ToolCallParseStatus(Enum):
@json_schema_type
class ToolCallDelta(BaseModel):
-    """A tool call content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "tool_call"
-    :param tool_call: Either an in-progress tool call string or the final parsed tool call
-    :param parse_status: Current parsing status of the tool call
-    """
    """A tool call content delta for streaming responses."""

    type: Literal["tool_call"] = "tool_call"

View file

@ -28,11 +28,7 @@ class JobStatus(Enum):
@json_schema_type
class Job(BaseModel):
-    """A job execution instance with status tracking.
-
-    :param job_id: Unique identifier for the job
-    :param status: Current execution status of the job
-    """
    """A job execution instance with status tracking."""

    job_id: str
    status: JobStatus

View file

@ -7,16 +7,13 @@
from enum import Enum
from typing import Any

-from pydantic import BaseModel
from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class Order(Enum):
-    """Sort order for paginated responses.
-
-    :cvar asc: Ascending order
-    :cvar desc: Descending order
-    """
    """Sort order for paginated responses."""

    asc = "asc"
    desc = "desc"
@ -24,13 +21,8 @@ class Order(Enum):
@json_schema_type
class PaginatedResponse(BaseModel):
-    """A generic paginated response that follows a simple format.
-
-    :param data: The list of items for the current page
-    :param has_more: Whether there are more items available after this set
-    :param url: The URL for accessing this list
-    """
-
-    data: list[dict[str, Any]]
-    has_more: bool
-    url: str | None = None
    """A generic paginated response that follows a simple format."""

    data: list[dict[str, Any]] = Field(description="The list of items for the current page.")
    has_more: bool = Field(description="Whether there are more items available after this set.")
    url: str | None = Field(default=None, description="The URL for accessing this list.")

View file

@ -6,42 +6,28 @@
from datetime import datetime

-from pydantic import BaseModel
from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class PostTrainingMetric(BaseModel):
-    """Training metrics captured during post-training jobs.
-
-    :param epoch: Training epoch number
-    :param train_loss: Loss value on the training dataset
-    :param validation_loss: Loss value on the validation dataset
-    :param perplexity: Perplexity metric indicating model confidence
-    """
-
-    epoch: int
-    train_loss: float
-    validation_loss: float
-    perplexity: float
    """Training metrics captured during post-training jobs."""

    epoch: int = Field(description="Training epoch number.")
    train_loss: float = Field(description="Loss value on the training dataset.")
    validation_loss: float = Field(description="Loss value on the validation dataset.")
    perplexity: float = Field(description="Perplexity metric indicating model confidence.")


@json_schema_type
class Checkpoint(BaseModel):
-    """Checkpoint created during training runs.
-
-    :param identifier: Unique identifier for the checkpoint
-    :param created_at: Timestamp when the checkpoint was created
-    :param epoch: Training epoch when the checkpoint was saved
-    :param post_training_job_id: Identifier of the training job that created this checkpoint
-    :param path: File system path where the checkpoint is stored
-    :param training_metrics: (Optional) Training metrics associated with this checkpoint
-    """
-
-    identifier: str
-    created_at: datetime
-    epoch: int
-    post_training_job_id: str
-    path: str
-    training_metrics: PostTrainingMetric | None = None
    """Checkpoint created during training runs."""

    identifier: str = Field(description="Unique identifier for the checkpoint.")
    created_at: datetime = Field(description="Timestamp when the checkpoint was created.")
    epoch: int = Field(description="Training epoch when the checkpoint was saved.")
    post_training_job_id: str = Field(description="Identifier of the training job that created this checkpoint.")
    path: str = Field(description="File system path where the checkpoint is stored.")
    training_metrics: PostTrainingMetric | None = Field(
        default=None, description="Training metrics associated with this checkpoint."
    )
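A small instantiation sketch for the two models above; the import path and all values are illustrative.

from datetime import datetime, timezone

# Module path assumed; adjust if these models live elsewhere in your checkout.
from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

metric = PostTrainingMetric(epoch=3, train_loss=0.41, validation_loss=0.47, perplexity=1.62)
checkpoint = Checkpoint(
    identifier="ckpt-epoch-3",
    created_at=datetime.now(timezone.utc),
    epoch=3,
    post_training_job_id="job-42",
    path="/checkpoints/job-42/epoch-3",
    training_metrics=metric,
)
print(checkpoint.model_dump_json(indent=2))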

View file

@ -4,28 +4,38 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from .conversations import (
# Import routes to trigger router registration
from . import routes  # noqa: F401
from .conversations_service import ConversationService
from .models import (
    Conversation,
    ConversationCreateRequest,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemCreateRequest,
    ConversationItemDeletedResource,
    ConversationItemInclude,
    ConversationItemList,
-    Conversations,
    ConversationMessage,
    ConversationUpdateRequest,
    Metadata,
)

# Backward compatibility - export Conversations as alias for ConversationService
Conversations = ConversationService

__all__ = [
-    "Conversation",
-    "ConversationCreateRequest",
-    "ConversationDeletedResource",
-    "ConversationItem",
-    "ConversationItemCreateRequest",
-    "ConversationItemDeletedResource",
-    "ConversationItemList",
    "Conversations",
    "ConversationService",
    "Conversation",
    "ConversationMessage",
    "ConversationItem",
    "ConversationCreateRequest",
    "ConversationUpdateRequest",
    "ConversationDeletedResource",
    "ConversationItemCreateRequest",
    "ConversationItemList",
    "ConversationItemDeletedResource",
    "ConversationItemInclude",
    "Metadata",
]

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemInclude,
ConversationItemList,
Metadata,
)
@runtime_checkable
@trace_protocol
class ConversationService(Protocol):
"""Conversations
Protocol for conversation management operations."""
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
...
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Retrieve a conversation."""
...
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation."""
...
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation."""
...
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items."""
...
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve an item."""
...
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items."""
...
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete an item."""
...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import StrEnum
-from typing import Annotated, Literal, Protocol, runtime_checkable
from typing import Annotated, Literal

from pydantic import BaseModel, Field
@ -20,9 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,
)
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.schema_utils import json_schema_type, register_schema

Metadata = dict[str, str]
@ -76,31 +74,6 @@ ConversationItem = Annotated[
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
@ -180,119 +153,3 @@ class ConversationItemDeletedResource(BaseModel):
    id: str = Field(..., description="The deleted item identifier")
    object: str = Field(default="conversation.item.deleted", description="Object type")
    deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Conversations
Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Retrieve a conversation.
Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation.
Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation.
Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items.
Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve an item.
Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items.
List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete an item.
Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...

View file

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .conversations_service import ConversationService
from .models import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemInclude,
ConversationItemList,
ConversationUpdateRequest,
)
def get_conversation_service(request: Request) -> ConversationService:
"""Dependency to get the conversation service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.conversations not in impls:
raise ValueError("Conversations API implementation not found")
return impls[Api.conversations]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Conversations"],
responses=standard_responses,
)
@router.post(
"/conversations",
response_model=Conversation,
summary="Create a conversation",
description="Create a conversation",
)
async def create_conversation(
body: ConversationCreateRequest = Body(...),
svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
"""Create a conversation."""
return await svc.create_conversation(items=body.items, metadata=body.metadata)
@router.get(
"/conversations/{conversation_id}",
response_model=Conversation,
summary="Retrieve a conversation",
description="Get a conversation with the given ID",
)
async def get_conversation(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
"""Get a conversation."""
return await svc.get_conversation(conversation_id=conversation_id)
@router.post(
"/conversations/{conversation_id}",
response_model=Conversation,
summary="Update a conversation",
description="Update a conversation's metadata with the given ID",
)
async def update_conversation(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
body: ConversationUpdateRequest = Body(...),
svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
"""Update a conversation."""
return await svc.update_conversation(conversation_id=conversation_id, metadata=body.metadata)
@router.delete(
"/conversations/{conversation_id}",
response_model=ConversationDeletedResource,
summary="Delete a conversation",
description="Delete a conversation with the given ID",
)
async def openai_delete_conversation(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
svc: ConversationService = Depends(get_conversation_service),
) -> ConversationDeletedResource:
"""Delete a conversation."""
return await svc.openai_delete_conversation(conversation_id=conversation_id)
@router.post(
"/conversations/{conversation_id}/items",
response_model=ConversationItemList,
summary="Create items",
description="Create items in the conversation",
)
async def add_items(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
body: ConversationItemCreateRequest = Body(...),
svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemList:
"""Create items in the conversation."""
return await svc.add_items(conversation_id=conversation_id, items=body.items)
@router.get(
"/conversations/{conversation_id}/items/{item_id}",
response_model=ConversationItem,
summary="Retrieve an item",
description="Retrieve a conversation item",
)
async def retrieve(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
item_id: Annotated[str, FastAPIPath(..., description="The item identifier")],
svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItem:
"""Retrieve a conversation item."""
return await svc.retrieve(conversation_id=conversation_id, item_id=item_id)
@router.get(
"/conversations/{conversation_id}/items",
response_model=ConversationItemList,
summary="List items",
description="List items in the conversation",
)
async def list_items(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
after: str | None = Query(None, description="An item ID to list items after, used in pagination"),
include: list[ConversationItemInclude] | None = Query(
None, description="Specify additional output data to include in the response"
),
limit: int | None = Query(None, description="A limit on the number of objects to be returned (1-100, default 20)"),
order: Literal["asc", "desc"] | None = Query(
None, description="The order to return items in (asc or desc, default desc)"
),
svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemList:
"""List items in the conversation."""
return await svc.list_items(conversation_id=conversation_id, after=after, include=include, limit=limit, order=order)
@router.delete(
"/conversations/{conversation_id}/items/{item_id}",
response_model=ConversationItemDeletedResource,
summary="Delete an item",
description="Delete a conversation item",
)
async def openai_delete_conversation_item(
conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
item_id: Annotated[str, FastAPIPath(..., description="The item identifier")],
svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
return await svc.openai_delete_conversation_item(conversation_id=conversation_id, item_id=item_id)
# For backward compatibility with the router registry system
def create_conversations_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Conversations API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.conversations, create_conversations_router)
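
The route modules in this change all share the same wiring: a module-level APIRouter, request-scoped dependencies that look up the implementation in app.state.impls, and a register_router call for the legacy registry. A minimal sketch of how a server or a test might assemble one of them; the llama_stack.apis.conversations module path, the stub service, and the rendered /v1 prefix are assumptions, not part of this commit.

from fastapi import FastAPI

from llama_stack.apis.conversations import routes as conversation_routes  # assumed module path
from llama_stack.apis.datatypes import Api


def build_app(conversation_impl) -> FastAPI:
    # conversation_impl is any object satisfying ConversationService (e.g. an in-memory stub).
    app = FastAPI()
    # get_conversation_service() resolves implementations from this dict at request time.
    app.state.impls = {Api.conversations: conversation_impl}
    app.include_router(conversation_routes.router)
    return app

# With fastapi.testclient.TestClient(build_app(stub)), a GET to
# "/v1/conversations/{conversation_id}" would then be handled by get_conversation() above.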


@@ -4,4 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .datasetio import *
+# Import routes to trigger router registration
from . import routes # noqa: F401
from .datasetio_service import DatasetIOService, DatasetStore
# Backward compatibility - export DatasetIO as alias for DatasetIOService
DatasetIO = DatasetIOService
__all__ = ["DatasetIO", "DatasetIOService", "DatasetStore"]
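
Because only the Protocol name changed, the alias keeps existing imports working. A quick sanity check, assuming the package lives at llama_stack.apis.datasetio:

from llama_stack.apis.datasetio import DatasetIO, DatasetIOService

# Both names refer to the same Protocol object after this change.
assert DatasetIO is DatasetIOService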


@@ -17,7 +17,7 @@ class DatasetStore(Protocol):
@runtime_checkable
-class DatasetIO(Protocol):
+class DatasetIOService(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@@ -28,28 +28,10 @@ class DatasetIO(Protocol):
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
-"""Get a paginated list of rows from a dataset.
-Uses offset-based pagination where:
-- start_index: The starting index (0-based). If None, starts from beginning.
-- limit: Number of items to return. If None or -1, returns all items.
-The response includes:
-- data: List of items for the current page.
-- has_more: Whether there are more items available after this set.
-:param dataset_id: The ID of the dataset to get the rows from.
-:param start_index: Index into dataset for the first row to get. Get all rows if None.
-:param limit: The number of rows to get.
-:returns: A PaginatedResponse.
-"""
+"""Get a paginated list of rows from a dataset."""
...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
-"""Append rows to a dataset.
-:param dataset_id: The ID of the dataset to append the rows to.
-:param rows: The rows to append to the dataset.
-"""
+"""Append rows to a dataset."""
...


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .datasetio_service import DatasetIOService
def get_datasetio_service(request: Request) -> DatasetIOService:
"""Dependency to get the datasetio service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.datasetio not in impls:
raise ValueError("DatasetIO API implementation not found")
return impls[Api.datasetio]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1BETA}",
tags=["DatasetIO"],
responses=standard_responses,
)
@router.get(
"/datasetio/iterrows/{dataset_id:path}",
response_model=PaginatedResponse,
summary="Get a paginated list of rows from a dataset.",
description="Get a paginated list of rows from a dataset using offset-based pagination.",
)
async def iterrows(
dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to get the rows from")],
start_index: int | None = Query(
None, description="Index into dataset for the first row to get. Get all rows if None."
),
limit: int | None = Query(None, description="The number of rows to get."),
svc: DatasetIOService = Depends(get_datasetio_service),
) -> PaginatedResponse:
"""Get a paginated list of rows from a dataset."""
return await svc.iterrows(dataset_id=dataset_id, start_index=start_index, limit=limit)
@router.post(
"/datasetio/append-rows/{dataset_id:path}",
response_model=None,
status_code=204,
summary="Append rows to a dataset.",
description="Append rows to a dataset.",
)
async def append_rows(
dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to append the rows to")],
body: list[dict[str, Any]] = Body(..., description="The rows to append to the dataset."),
svc: DatasetIOService = Depends(get_datasetio_service),
) -> None:
"""Append rows to a dataset."""
await svc.append_rows(dataset_id=dataset_id, rows=body)
# For backward compatibility with the router registry system
def create_datasetio_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the DatasetIO API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.datasetio, create_datasetio_router)
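
A rough client-side sketch of paging through the new iterrows route. The base URL and the literal v1beta prefix are assumptions; the data/has_more fields follow the PaginatedResponse contract this endpoint returns.

import httpx


def fetch_all_rows(base_url: str, dataset_id: str, page_size: int = 100) -> list[dict]:
    rows: list[dict] = []
    start_index = 0
    while True:
        resp = httpx.get(
            f"{base_url}/v1beta/datasetio/iterrows/{dataset_id}",
            params={"start_index": start_index, "limit": page_size},
        )
        resp.raise_for_status()
        page = resp.json()
        rows.extend(page["data"])
        if not page.get("has_more"):
            return rows
        start_index += len(page["data"])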


@@ -4,4 +4,36 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .datasets import *
+# Import routes to trigger router registration
from . import routes # noqa: F401
from .datasets_service import DatasetsService
from .models import (
CommonDatasetFields,
Dataset,
DatasetInput,
DatasetPurpose,
DatasetType,
DataSource,
ListDatasetsResponse,
RegisterDatasetRequest,
RowsDataSource,
URIDataSource,
)
# Backward compatibility - export Datasets as alias for DatasetsService
Datasets = DatasetsService
__all__ = [
"Datasets",
"DatasetsService",
"Dataset",
"DatasetInput",
"CommonDatasetFields",
"DatasetPurpose",
"DatasetType",
"DataSource",
"URIDataSource",
"RowsDataSource",
"ListDatasetsResponse",
"RegisterDatasetRequest",
]


@@ -1,247 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(StrEnum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: list[dict[str, Any]]
DataSource = Annotated[
URIDataSource | RowsDataSource,
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
"""
Common fields for a dataset.
:param purpose: Purpose of the dataset indicating its intended use
:param source: Data source configuration for the dataset
:param metadata: Additional metadata for the dataset
"""
purpose: DatasetPurpose
source: DataSource
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
"""Dataset resource for storing and accessing training or evaluation data.
:param type: Type of resource, always 'dataset' for datasets
"""
type: Literal[ResourceType.dataset] = ResourceType.dataset
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str | None:
return self.provider_resource_id
class DatasetInput(CommonDatasetFields, BaseModel):
"""Input parameters for dataset operations.
:param dataset_id: Unique identifier for the dataset
"""
dataset_id: str
class ListDatasetsResponse(BaseModel):
"""Response from listing datasets.
:param data: List of datasets
"""
data: list[Dataset]
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def get_dataset(
self,
dataset_id: str,
) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...


@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import Dataset, DatasetPurpose, DataSource, ListDatasetsResponse
@runtime_checkable
@trace_protocol
class DatasetsService(Protocol):
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset.
:param metadata: The metadata for the dataset.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
"""
...
async def get_dataset(
self,
dataset_id: str,
) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...


@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema
class DatasetPurpose(StrEnum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
{
"question": "What is the capital of France?",
"answer": "Paris"
}
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI."""
type: Literal["uri"] = Field(default="uri", description="The type of data source")
uri: str = Field(
...,
description="The dataset can be obtained from a URI. E.g. 'https://mywebsite.com/mydata.jsonl', 'lsfs://mydata.jsonl', 'data:csv;base64,{base64_content}'",
)
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows."""
type: Literal["rows"] = Field(default="rows", description="The type of data source")
rows: list[dict[str, Any]] = Field(
...,
description="The dataset is stored in rows. E.g. [{'messages': [{'role': 'user', 'content': 'Hello, world!'}, {'role': 'assistant', 'content': 'Hello, world!'}]}]",
)
DataSource = Annotated[
URIDataSource | RowsDataSource,
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
"""Common fields for a dataset."""
purpose: DatasetPurpose
source: DataSource
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
"""Dataset resource for storing and accessing training or evaluation data."""
type: Literal[ResourceType.dataset] = Field(
default=ResourceType.dataset, description="Type of resource, always 'dataset' for datasets"
)
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str | None:
return self.provider_resource_id
class DatasetInput(CommonDatasetFields, BaseModel):
"""Input parameters for dataset operations."""
dataset_id: str = Field(..., description="Unique identifier for the dataset")
class ListDatasetsResponse(BaseModel):
"""Response from listing datasets."""
data: list[Dataset] = Field(..., description="List of datasets")
@json_schema_type
class RegisterDatasetRequest(BaseModel):
"""Request model for registering a dataset."""
purpose: DatasetPurpose = Field(..., description="The purpose of the dataset")
source: DataSource = Field(..., description="The data source of the dataset")
metadata: dict[str, Any] | None = Field(default=None, description="The metadata for the dataset")
dataset_id: str | None = Field(
default=None, description="The ID of the dataset. If not provided, an ID will be generated"
)


@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .datasets_service import DatasetsService
from .models import Dataset, ListDatasetsResponse, RegisterDatasetRequest
def get_datasets_service(request: Request) -> DatasetsService:
"""Dependency to get the datasets service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.datasets not in impls:
raise ValueError("Datasets API implementation not found")
return impls[Api.datasets]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Datasets"],
responses=standard_responses,
)
router_v1beta = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1BETA}",
tags=["Datasets"],
responses=standard_responses,
)
@router.post(
"/datasets",
response_model=Dataset,
summary="Register a new dataset",
description="Register a new dataset",
deprecated=True,
)
@router_v1beta.post(
"/datasets",
response_model=Dataset,
summary="Register a new dataset",
description="Register a new dataset",
)
async def register_dataset(
body: RegisterDatasetRequest = Body(...),
svc: DatasetsService = Depends(get_datasets_service),
) -> Dataset:
"""Register a new dataset."""
return await svc.register_dataset(
purpose=body.purpose,
source=body.source,
metadata=body.metadata,
dataset_id=body.dataset_id,
)
@router.get(
"/datasets/{dataset_id:path}",
response_model=Dataset,
summary="Get a dataset by its ID",
description="Get a dataset by its ID",
deprecated=True,
)
@router_v1beta.get(
"/datasets/{{dataset_id:path}}",
response_model=Dataset,
summary="Get a dataset by its ID",
description="Get a dataset by its ID",
)
async def get_dataset(
dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to get")],
svc: DatasetsService = Depends(get_datasets_service),
) -> Dataset:
"""Get a dataset by its ID."""
return await svc.get_dataset(dataset_id=dataset_id)
@router.get(
"/datasets",
response_model=ListDatasetsResponse,
summary="List all datasets",
description="List all datasets",
deprecated=True,
)
@router_v1beta.get(
"/datasets",
response_model=ListDatasetsResponse,
summary="List all datasets",
description="List all datasets",
)
async def list_datasets(svc: DatasetsService = Depends(get_datasets_service)) -> ListDatasetsResponse:
"""List all datasets."""
return await svc.list_datasets()
@router.delete(
"/datasets/{dataset_id:path}",
response_model=None,
status_code=204,
summary="Unregister a dataset by its ID",
description="Unregister a dataset by its ID",
deprecated=True,
)
@router_v1beta.delete(
"/datasets/{{dataset_id:path}}",
response_model=None,
status_code=204,
summary="Unregister a dataset by its ID",
description="Unregister a dataset by its ID",
)
async def unregister_dataset(
dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to unregister")],
svc: DatasetsService = Depends(get_datasets_service),
) -> None:
"""Unregister a dataset by its ID."""
await svc.unregister_dataset(dataset_id=dataset_id)
# For backward compatibility with the router registry system
def create_datasets_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Datasets API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1beta)
return main_router
# Register the router factory
register_router(Api.datasets, create_datasets_router)
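
For orientation, a hedged example of registering a dataset against the new v1beta route. The server address is an assumption; the payload mirrors RegisterDatasetRequest (purpose, source, metadata, dataset_id) and reuses the URI source shape documented above.

import httpx

payload = {
    "purpose": "eval/question-answer",
    "source": {"type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"},
    "metadata": {"description": "My dataset"},
}
resp = httpx.post("http://localhost:8321/v1beta/datasets", json=payload)  # base URL assumed
resp.raise_for_status()
dataset = resp.json()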


@@ -130,23 +130,22 @@ class Api(Enum, metaclass=DynamicApiMeta):
# built-in API
inspect = "inspect"
synthetic_data_generation = "synthetic_data_generation"
@json_schema_type
class Error(BaseModel):
-"""
-Error response from the API. Roughly follows RFC 7807.
-:param status: HTTP status code
-:param title: Error title, a short summary of the error which is invariant for an error type
-:param detail: Error detail, a longer human-readable description of the error
-:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
-"""
-status: int
-title: str
-detail: str
-instance: str | None = None
+"""Error response from the API. Roughly follows RFC 7807."""
+status: int = Field(..., description="HTTP status code")
+title: str = Field(
+    ..., description="Error title, a short summary of the error which is invariant for an error type"
+)
+detail: str = Field(..., description="Error detail, a longer human-readable description of the error")
+instance: str | None = Field(
+    None,
+    description="(Optional) A URL which can be used to retrieve more information about the specific occurrence of the error",
+)
class ExternalApiSpec(BaseModel):


@@ -4,4 +4,28 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .eval import *
+# Import routes to trigger router registration
from . import routes # noqa: F401
from .eval_service import EvalService
from .models import (
AgentCandidate,
BenchmarkConfig,
EvalCandidate,
EvaluateResponse,
EvaluateRowsRequest,
ModelCandidate,
)
# Backward compatibility - export Eval as alias for EvalService
Eval = EvalService
__all__ = [
"Eval",
"EvalService",
"ModelCandidate",
"AgentCandidate",
"EvalCandidate",
"BenchmarkConfig",
"EvaluateResponse",
"EvaluateRowsRequest",
]


@@ -1,150 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: SystemMessage | None = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: int | None = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optinally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: list[dict[str, Any]]
# each key in the dict is a scoring function name
scores: dict[str, ScoringResult]
class Eval(Protocol):
"""Evaluations
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:returns: The job that was created to run the evaluation.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:returns: EvaluateResponse object containing generations and scores.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:returns: The status of the evaluation job.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
)
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:returns: The result of the job.
"""
...


@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import BenchmarkConfig, EvaluateResponse
@runtime_checkable
@trace_protocol
class EvalService(Protocol):
"""Evaluations
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark."""
...
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark."""
...
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job."""
...
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job."""
...
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job."""
...


@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring.models import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation."""
type: Literal["model"] = Field(default="model", description="The type of candidate.")
model: str = Field(..., description="The model ID to evaluate.")
sampling_params: SamplingParams = Field(..., description="The sampling parameters for the model.")
system_message: SystemMessage | None = Field(
default=None, description="The system message providing instructions or context to the model."
)
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation."""
type: Literal["agent"] = Field(default="agent", description="The type of candidate.")
config: AgentConfig = Field(..., description="The configuration for the agent candidate.")
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation."""
eval_candidate: EvalCandidate = Field(..., description="The candidate to evaluate.")
scoring_params: dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run.",
default_factory=dict,
)
num_examples: int | None = Field(
description="The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated.",
default=None,
)
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation."""
generations: list[dict[str, Any]] = Field(..., description="The generations from the evaluation.")
scores: dict[str, ScoringResult] = Field(
..., description="The scores from the evaluation. Each key in the dict is a scoring function name."
)
@json_schema_type
class EvaluateRowsRequest(BaseModel):
"""Request model for evaluating rows."""
input_rows: list[dict[str, Any]] = Field(..., description="The rows to evaluate.")
scoring_functions: list[str] = Field(..., description="The scoring functions to use for the evaluation.")
benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark.")


@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .eval_service import EvalService
from .models import BenchmarkConfig, EvaluateResponse, EvaluateRowsRequest
def get_eval_service(request: Request) -> EvalService:
"""Dependency to get the eval service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.eval not in impls:
raise ValueError("Eval API implementation not found")
return impls[Api.eval]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Eval"],
responses=standard_responses,
)
router_v1alpha = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
tags=["Eval"],
responses=standard_responses,
)
@router.post(
"/eval/benchmarks/{benchmark_id}/jobs",
response_model=Job,
summary="Run an evaluation on a benchmark",
description="Run an evaluation on a benchmark",
deprecated=True,
)
@router_v1alpha.post(
"/eval/benchmarks/{{benchmark_id}}/jobs",
response_model=Job,
summary="Run an evaluation on a benchmark",
description="Run an evaluation on a benchmark",
)
async def run_eval(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
benchmark_config: BenchmarkConfig = Body(...),
svc: EvalService = Depends(get_eval_service),
) -> Job:
"""Run an evaluation on a benchmark."""
return await svc.run_eval(benchmark_id=benchmark_id, benchmark_config=benchmark_config)
@router.post(
"/eval/benchmarks/{benchmark_id}/evaluations",
response_model=EvaluateResponse,
summary="Evaluate a list of rows on a benchmark",
description="Evaluate a list of rows on a benchmark",
deprecated=True,
)
@router_v1alpha.post(
"/eval/benchmarks/{{benchmark_id}}/evaluations",
response_model=EvaluateResponse,
summary="Evaluate a list of rows on a benchmark",
description="Evaluate a list of rows on a benchmark",
)
async def evaluate_rows(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
body: EvaluateRowsRequest = Body(...),
svc: EvalService = Depends(get_eval_service),
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark."""
return await svc.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=body.input_rows,
scoring_functions=body.scoring_functions,
benchmark_config=body.benchmark_config,
)
@router.get(
"/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
response_model=Job,
summary="Get the status of a job",
description="Get the status of a job",
deprecated=True,
)
@router_v1alpha.get(
"/eval/benchmarks/{{benchmark_id}}/jobs/{{job_id}}",
response_model=Job,
summary="Get the status of a job",
description="Get the status of a job",
)
async def job_status(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to get the status of")],
svc: EvalService = Depends(get_eval_service),
) -> Job:
"""Get the status of a job."""
return await svc.job_status(benchmark_id=benchmark_id, job_id=job_id)
@router.delete(
"/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
response_model=None,
status_code=204,
summary="Cancel a job",
description="Cancel a job",
deprecated=True,
)
@router_v1alpha.delete(
"/eval/benchmarks/{{benchmark_id}}/jobs/{{job_id}}",
response_model=None,
status_code=204,
summary="Cancel a job",
description="Cancel a job",
)
async def job_cancel(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to cancel")],
svc: EvalService = Depends(get_eval_service),
) -> None:
"""Cancel a job."""
await svc.job_cancel(benchmark_id=benchmark_id, job_id=job_id)
@router.get(
"/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
response_model=EvaluateResponse,
summary="Get the result of a job",
description="Get the result of a job",
deprecated=True,
)
@router_v1alpha.get(
"/eval/benchmarks/{{benchmark_id}}/jobs/{{job_id}}/result",
response_model=EvaluateResponse,
summary="Get the result of a job",
description="Get the result of a job",
)
async def job_result(
benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to get the result of")],
svc: EvalService = Depends(get_eval_service),
) -> EvaluateResponse:
"""Get the result of a job."""
return await svc.job_result(benchmark_id=benchmark_id, job_id=job_id)
# For backward compatibility with the router registry system
def create_eval_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Eval API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1alpha)
return main_router
# Register the router factory
register_router(Api.eval, create_eval_router)
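
The eval routes are job-oriented: submit a run, poll its status, then fetch the result. A rough client sketch; the server address, the rendered v1alpha prefix, and the job_id/status field names on the Job payload are assumptions.

import time

import httpx

BASE = "http://localhost:8321/v1alpha"  # assumed address and prefix
BENCHMARK = "my-benchmark"  # hypothetical benchmark registered elsewhere

# The request body is a BenchmarkConfig: an eval candidate plus optional scoring params.
config = {
    "eval_candidate": {"type": "model", "model": "my-model", "sampling_params": {}},
    "scoring_params": {},
}

job = httpx.post(f"{BASE}/eval/benchmarks/{BENCHMARK}/jobs", json=config).json()
job_id = job["job_id"]  # field name assumed

while True:
    status = httpx.get(f"{BASE}/eval/benchmarks/{BENCHMARK}/jobs/{job_id}").json()
    if status.get("status") not in ("scheduled", "in_progress"):  # status values assumed
        break
    time.sleep(5)

result = httpx.get(f"{BASE}/eval/benchmarks/{BENCHMARK}/jobs/{job_id}/result").json()
print(result["scores"])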


@@ -4,4 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .files import *
+# Import routes to trigger router registration
from . import routes # noqa: F401
from .files_service import FileService
from .models import (
ExpiresAfter,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
# Backward compatibility - export Files as alias for FileService
Files = FileService
__all__ = [
"Files",
"FileService",
"OpenAIFileObject",
"OpenAIFilePurpose",
"ExpiresAfter",
"ListOpenAIFileResponse",
"OpenAIFileDeleteResponse",
]


@@ -1,194 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Files API Models
class OpenAIFilePurpose(StrEnum):
"""
Valid purpose values for OpenAI Files API.
"""
ASSISTANTS = "assistants"
BATCH = "batch"
# TODO: Add other purposes as needed
@json_schema_type
class OpenAIFileObject(BaseModel):
"""
OpenAI File object as defined in the OpenAI Files API.
:param object: The object type, which is always "file"
:param id: The file identifier, which can be referenced in the API endpoints
:param bytes: The size of the file, in bytes
:param created_at: The Unix timestamp (in seconds) for when the file was created
:param expires_at: The Unix timestamp (in seconds) for when the file expires
:param filename: The name of the file
:param purpose: The intended purpose of the file
"""
object: Literal["file"] = "file"
id: str
bytes: int
created_at: int
expires_at: int
filename: str
purpose: OpenAIFilePurpose
@json_schema_type
class ExpiresAfter(BaseModel):
"""
Control expiration of uploaded files.
Params:
- anchor, must be "created_at"
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
"""
MIN: ClassVar[int] = 3600 # 1 hour
MAX: ClassVar[int] = 2592000 # 30 days
anchor: Literal["created_at"]
seconds: int = Field(..., ge=3600, le=2592000)
@json_schema_type
class ListOpenAIFileResponse(BaseModel):
"""
Response for listing files in OpenAI Files API.
:param data: List of file objects
:param has_more: Whether there are more files available beyond this page
:param first_id: ID of the first file in the list for pagination
:param last_id: ID of the last file in the list for pagination
:param object: The object type, which is always "list"
"""
data: list[OpenAIFileObject]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
@json_schema_type
class OpenAIFileDeleteResponse(BaseModel):
"""
Response for deleting a file in OpenAI Files API.
:param id: The file identifier that was deleted
:param object: The object type, which is always "file"
:param deleted: Whether the file was successfully deleted
"""
id: str
object: Literal["file"] = "file"
deleted: bool
@runtime_checkable
@trace_protocol
class Files(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""Upload file.
Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded.
- purpose: The intended purpose of the uploaded file.
- expires_after: Optional form values describing expiration for the file.
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
:param expires_after: Optional form values describing expiration for the file.
:returns: An OpenAIFileObject representing the uploaded file.
"""
...
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""List files.
Returns a list of files that belong to the user's organization.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param purpose: Only return files with the given purpose.
:returns: An ListOpenAIFileResponse containing the list of files.
"""
...
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
file_id: str,
) -> OpenAIFileObject:
"""Retrieve file.
Returns information about a specific file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileObject containing file information.
"""
...
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
file_id: str,
) -> OpenAIFileDeleteResponse:
"""Delete file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
"""
...
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,
file_id: str,
) -> Response:
"""Retrieve file content.
Returns the contents of the specified file.
:param file_id: The ID of the file to use for this request.
:returns: The raw file content as a binary response.
"""
...


@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Protocol, runtime_checkable
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.responses import Order
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import (
ExpiresAfter,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
@runtime_checkable
@trace_protocol
class FileService(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""Upload file."""
...
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""List files."""
...
async def openai_retrieve_file(
self,
file_id: str,
) -> OpenAIFileObject:
"""Retrieve file."""
...
async def openai_delete_file(
self,
file_id: str,
) -> OpenAIFileDeleteResponse:
"""Delete file."""
...
async def openai_retrieve_file_content(
self,
file_id: str,
) -> Response:
"""Retrieve file content."""
...


@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import ClassVar, Literal
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class OpenAIFilePurpose(StrEnum):
"""
Valid purpose values for OpenAI Files API.
"""
ASSISTANTS = "assistants"
BATCH = "batch"
# TODO: Add other purposes as needed
@json_schema_type
class OpenAIFileObject(BaseModel):
"""OpenAI File object as defined in the OpenAI Files API."""
object: Literal["file"] = Field(default="file", description="The object type, which is always 'file'.")
id: str = Field(..., description="The file identifier, which can be referenced in the API endpoints.")
bytes: int = Field(..., description="The size of the file, in bytes.")
created_at: int = Field(..., description="The Unix timestamp (in seconds) for when the file was created.")
expires_at: int = Field(..., description="The Unix timestamp (in seconds) for when the file expires.")
filename: str = Field(..., description="The name of the file.")
purpose: OpenAIFilePurpose = Field(..., description="The intended purpose of the file.")
@json_schema_type
class ExpiresAfter(BaseModel):
"""Control expiration of uploaded files."""
MIN: ClassVar[int] = 3600 # 1 hour
MAX: ClassVar[int] = 2592000 # 30 days
anchor: Literal["created_at"] = Field(..., description="Anchor must be 'created_at'.")
seconds: int = Field(..., ge=3600, le=2592000, description="Seconds between 3600 and 2592000 (1 hour to 30 days).")
@json_schema_type
class ListOpenAIFileResponse(BaseModel):
"""Response for listing files in OpenAI Files API."""
data: list[OpenAIFileObject] = Field(..., description="List of file objects.")
has_more: bool = Field(..., description="Whether there are more files available beyond this page.")
first_id: str = Field(..., description="ID of the first file in the list for pagination.")
last_id: str = Field(..., description="ID of the last file in the list for pagination.")
object: Literal["list"] = Field(default="list", description="The object type, which is always 'list'.")
@json_schema_type
class OpenAIFileDeleteResponse(BaseModel):
"""Response for deleting a file in OpenAI Files API."""
id: str = Field(..., description="The file identifier that was deleted.")
object: Literal["file"] = Field(default="file", description="The object type, which is always 'file'.")
deleted: bool = Field(..., description="Whether the file was successfully deleted.")


@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Depends, File, Form, Query, Request, Response, UploadFile
from fastapi import Path as FastAPIPath
from llama_stack.apis.common.responses import Order
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .files_service import FileService
from .models import (
ExpiresAfter,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
def get_file_service(request: Request) -> FileService:
"""Dependency to get the file service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.files not in impls:
raise ValueError("Files API implementation not found")
return impls[Api.files]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Files"],
responses=standard_responses,
)
@router.post(
"/files",
response_model=OpenAIFileObject,
summary="Upload file.",
description="Upload a file that can be used across various endpoints.",
)
async def openai_upload_file(
file: Annotated[UploadFile, File(..., description="The File object to be uploaded.")],
purpose: Annotated[OpenAIFilePurpose, Form(..., description="The intended purpose of the uploaded file.")],
expires_after: Annotated[
ExpiresAfter | None, Form(description="Optional form values describing expiration for the file.")
] = None,
svc: FileService = Depends(get_file_service),
) -> OpenAIFileObject:
"""Upload a file."""
return await svc.openai_upload_file(file=file, purpose=purpose, expires_after=expires_after)
@router.get(
"/files",
response_model=ListOpenAIFileResponse,
summary="List files.",
description="Returns a list of files that belong to the user's organization.",
)
async def openai_list_files(
after: str | None = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
limit: int | None = Query(
10000,
description="A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.",
),
order: Order | None = Query(
Order.desc,
description="Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.",
),
purpose: OpenAIFilePurpose | None = Query(None, description="Only return files with the given purpose."),
svc: FileService = Depends(get_file_service),
) -> ListOpenAIFileResponse:
"""List files."""
return await svc.openai_list_files(after=after, limit=limit, order=order, purpose=purpose)
@router.get(
"/files/{file_id}",
response_model=OpenAIFileObject,
summary="Retrieve file.",
description="Returns information about a specific file.",
)
async def openai_retrieve_file(
file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
svc: FileService = Depends(get_file_service),
) -> OpenAIFileObject:
"""Retrieve file information."""
return await svc.openai_retrieve_file(file_id=file_id)
@router.delete(
"/files/{file_id}",
response_model=OpenAIFileDeleteResponse,
summary="Delete file.",
description="Delete a file.",
)
async def openai_delete_file(
file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
svc: FileService = Depends(get_file_service),
) -> OpenAIFileDeleteResponse:
"""Delete a file."""
return await svc.openai_delete_file(file_id=file_id)
@router.get(
"/files/{file_id}/content",
response_class=Response,
summary="Retrieve file content.",
description="Returns the contents of the specified file.",
)
async def openai_retrieve_file_content(
file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
svc: FileService = Depends(get_file_service),
) -> Response:
"""Retrieve file content."""
return await svc.openai_retrieve_file_content(file_id=file_id)
# For backward compatibility with the router registry system
def create_files_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Files API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.files, create_files_router)
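
A minimal, hypothetical upload against the new files route: httpx sends a multipart form with the file part plus the purpose form field, matching the File/Form parameters above. The server address and rendered v1 prefix are assumptions; expires_after is omitted for brevity.

import httpx

with open("report.pdf", "rb") as fh:
    resp = httpx.post(
        "http://localhost:8321/v1/files",  # base URL assumed
        files={"file": ("report.pdf", fh, "application/pdf")},
        data={"purpose": "assistants"},
    )
resp.raise_for_status()
file_obj = resp.json()
print(file_obj["id"], file_obj["expires_at"])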


@@ -4,4 +4,206 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .inference import *
+# Import routes to trigger router registration
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
from . import routes # noqa: F401
from .inference_service import InferenceService
from .models import (
Bf16QuantizationConfig,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Fp8QuantizationConfig,
GrammarResponseFormat,
GreedySamplingStrategy,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
ModelStore,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionTextOnlyMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChatCompletionUsage,
OpenAIChatCompletionUsageCompletionTokensDetails,
OpenAIChatCompletionUsagePromptTokensDetails,
OpenAIChoice,
OpenAIChoiceDelta,
OpenAIChoiceLogprobs,
OpenAIChunkChoice,
OpenAICompletion,
OpenAICompletionChoice,
OpenAICompletionLogprobs,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIDeveloperMessageParam,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIFile,
OpenAIFileFile,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAITokenLogProb,
OpenAIToolMessageParam,
OpenAITopLogProb,
OpenAIUserMessageParam,
QuantizationConfig,
QuantizationType,
RerankData,
RerankResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
SamplingStrategy,
SystemMessage,
SystemMessageBehavior,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolResponse,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
# Backward compatibility - export Inference as alias for InferenceService
Inference = InferenceService
InferenceProvider = InferenceService
__all__ = [
"Inference",
"InferenceProvider",
"InferenceService",
"InterleavedContent",
"ModelStore",
"Order",
# Sampling
"SamplingParams",
"SamplingStrategy",
"GreedySamplingStrategy",
"TopPSamplingStrategy",
"TopKSamplingStrategy",
# Quantization
"QuantizationConfig",
"QuantizationType",
"Bf16QuantizationConfig",
"Fp8QuantizationConfig",
"Int4QuantizationConfig",
# Messages
"Message",
"UserMessage",
"SystemMessage",
"ToolResponseMessage",
"CompletionMessage",
# Tools
"BuiltinTool",
"ToolCall",
"ToolChoice",
"ToolConfig",
"ToolDefinition",
"ToolPromptFormat",
"ToolResponse",
# StopReason
"StopReason",
# Completion
"CompletionRequest",
"CompletionResponse",
"CompletionResponseStreamChunk",
# Chat Completion
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionResponseStreamChunk",
"ChatCompletionResponseEvent",
"ChatCompletionResponseEventType",
# Embeddings
"EmbeddingsResponse",
"EmbeddingTaskType",
"TextTruncation",
# Rerank
"RerankResponse",
"RerankData",
# Response Format
"ResponseFormat",
"ResponseFormatType",
"JsonSchemaResponseFormat",
"GrammarResponseFormat",
# Log Probs
"LogProbConfig",
"TokenLogProbs",
# System Message Behavior
"SystemMessageBehavior",
# OpenAI Models
"OpenAICompletion",
"OpenAICompletionRequestWithExtraBody",
"OpenAICompletionChoice",
"OpenAICompletionLogprobs",
"OpenAIChatCompletion",
"OpenAIChatCompletionRequestWithExtraBody",
"OpenAIChatCompletionChunk",
"OpenAIChatCompletionUsage",
"OpenAIChatCompletionUsageCompletionTokensDetails",
"OpenAIChatCompletionUsagePromptTokensDetails",
"OpenAIChoice",
"OpenAIChoiceDelta",
"OpenAIChoiceLogprobs",
"OpenAIChunkChoice",
"OpenAIMessageParam",
"OpenAIUserMessageParam",
"OpenAISystemMessageParam",
"OpenAIAssistantMessageParam",
"OpenAIToolMessageParam",
"OpenAIDeveloperMessageParam",
"OpenAIChatCompletionContentPartParam",
"OpenAIChatCompletionContentPartTextParam",
"OpenAIChatCompletionContentPartImageParam",
"OpenAIChatCompletionMessageContent",
"OpenAIChatCompletionTextOnlyMessageContent",
"OpenAIChatCompletionToolCall",
"OpenAIChatCompletionToolCallFunction",
"OpenAIEmbeddingsRequestWithExtraBody",
"OpenAIEmbeddingsResponse",
"OpenAIEmbeddingData",
"OpenAIEmbeddingUsage",
"OpenAIResponseFormatParam",
"OpenAIResponseFormatText",
"OpenAIResponseFormatJSONSchema",
"OpenAIResponseFormatJSONObject",
"OpenAIJSONSchema",
"OpenAIImageURL",
"OpenAIFile",
"OpenAIFileFile",
"OpenAITokenLogProb",
"OpenAITopLogProb",
"OpenAICompletionWithInputMessages",
"ListOpenAIChatCompletionResponse",
]

File diff suppressed because it is too large


@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Annotated, Protocol, runtime_checkable
from fastapi import Body
from llama_stack.apis.common.responses import Order
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import (
ListOpenAIChatCompletionResponse,
ModelStore,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
RerankResponse,
)
@runtime_checkable
@trace_protocol
class InferenceService(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models reorder the documents based on their relevance to a query.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query."""
...
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
"""Create completion."""
...
async def openai_chat_completion(
self,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Create chat completions."""
...
async def openai_embeddings(
self,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
"""Create embeddings."""
...
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List chat completions."""
...
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Get chat completion."""
...
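# Illustrative sketch, not part of this commit: a toy provider satisfying the non-streaming
# chat path of InferenceService. The class name, completion id, and echo behaviour are
# placeholders, not anything defined by this diff.
import time

from .models import (
    OpenAIAssistantMessageParam,
    OpenAIChoice,
)


class EchoInference:
    """Toy provider that echoes the last user message back as the assistant reply."""

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion:
        last = params.messages[-1]
        text = last.content if isinstance(last.content, str) else ""
        return OpenAIChatCompletion(
            id="chatcmpl-echo",
            created=int(time.time()),
            model=params.model,
            choices=[
                OpenAIChoice(
                    index=0,
                    finish_reason="stop",
                    message=OpenAIAssistantMessageParam(content=text),
                )
            ],
        )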


@ -0,0 +1,818 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.core.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, register_schema
register_schema(ToolCall)
register_schema(ToolDefinition)
@json_schema_type
class GreedySamplingStrategy(BaseModel):
"""Greedy sampling strategy that selects the highest probability token at each step."""
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
"""Top-p (nucleus) sampling strategy that samples from the smallest set of tokens with cumulative probability >= p."""
type: Literal["top_p"] = "top_p"
temperature: float | None = Field(..., gt=0.0)
top_p: float | None = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
"""Top-k sampling strategy that restricts sampling to the k most likely tokens."""
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters."""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: int | None = None
repetition_penalty: float | None = 1.0
stop: list[str] | None = None
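# Illustrative sketch, not part of this commit: because SamplingStrategy is a discriminated
# union on "type", SamplingParams can be validated straight from a plain JSON-style dict.
# The numbers and stop string below are arbitrary example values.
def _example_sampling_params() -> SamplingParams:
    return SamplingParams.model_validate(
        {
            "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.9},
            "max_tokens": 128,
            "stop": ["</answer>"],
        }
    )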
class LogProbConfig(BaseModel):
"""Configuration for log probability generation."""
top_k: int | None = 0
class QuantizationType(Enum):
"""Type of model quantization to run inference with."""
bf16 = "bf16"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
"""Configuration for 8-bit floating point quantization."""
type: Literal["fp8_mixed"] = "fp8_mixed"
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
"""Configuration for BFloat16 precision (typically no quantization)."""
type: Literal["bf16"] = "bf16"
@json_schema_type
class Int4QuantizationConfig(BaseModel):
"""Configuration for 4-bit integer quantization."""
type: Literal["int4_mixed"] = "int4_mixed"
scheme: str | None = "int4_weight_int8_dynamic_activation"
QuantizationConfig = Annotated[
Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
Field(discriminator="type"),
]
@json_schema_type
class UserMessage(BaseModel):
"""A message from the user in a chat conversation."""
role: Literal["user"] = "user"
content: InterleavedContent
context: InterleavedContent | None = None
@json_schema_type
class SystemMessage(BaseModel):
"""A system message providing instructions or context to the model."""
role: Literal["system"] = "system"
content: InterleavedContent
@json_schema_type
class ToolResponseMessage(BaseModel):
"""A message representing the result of a tool invocation."""
role: Literal["tool"] = "tool"
call_id: str
content: InterleavedContent
@json_schema_type
class CompletionMessage(BaseModel):
"""A message containing the model's (assistant) response in a chat conversation.
- `StopReason.end_of_turn`: The model finished generating the entire response.
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
- `StopReason.out_of_tokens`: The model ran out of token budget.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
class ToolResponse(BaseModel):
"""Response from a tool invocation."""
call_id: str
tool_name: BuiltinTool | str
content: InterleavedContent
metadata: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
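# Coerce known builtin tool names to BuiltinTool; leave custom tool names as plain strings.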
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."""
auto = "auto"
required = "required"
none = "none"
@json_schema_type
class TokenLogProbs(BaseModel):
"""Log probabilities for generated tokens."""
logprobs_by_token: dict[str, float]
class ChatCompletionResponseEventType(Enum):
"""Types of events that can occur during chat completion."""
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""An event during chat completion generation."""
event_type: ChatCompletionResponseEventType
delta: ContentDelta
logprobs: list[TokenLogProbs] | None = None
stop_reason: StopReason | None = None
class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding."""
json_schema = "json_schema"
grammar = "grammar"
@json_schema_type
class JsonSchemaResponseFormat(BaseModel):
"""Configuration for JSON schema-guided response generation."""
type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
json_schema: dict[str, Any]
@json_schema_type
class GrammarResponseFormat(BaseModel):
"""Configuration for grammar-guided response generation."""
type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
bnf: dict[str, Any]
ResponseFormat = Annotated[
JsonSchemaResponseFormat | GrammarResponseFormat,
Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
# This is an internally used class
class CompletionRequest(BaseModel):
content: InterleavedContent
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request."""
content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response."""
delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
append = "append"
replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
def model_post_init(self, __context: Any) -> None:
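# Accept tool_choice passed as a plain string by coercing it to the ToolChoice enum when possible.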
if isinstance(self.tool_choice, str):
try:
self.tool_choice = ToolChoice[self.tool_choice]
except KeyError:
pass
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response."""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request."""
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class EmbeddingsResponse(BaseModel):
"""Response containing generated embeddings."""
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response."""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request."""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages."""
type: Literal["text"] = "text"
text: str
@json_schema_type
class OpenAIImageURL(BaseModel):
"""Image URL specification for OpenAI-compatible chat completion messages."""
url: str
detail: str | None = None
@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
"""Image content part for OpenAI-compatible chat completion messages."""
type: Literal["image_url"] = "image_url"
image_url: OpenAIImageURL
@json_schema_type
class OpenAIFileFile(BaseModel):
file_id: str | None = None
filename: str | None = None
@json_schema_type
class OpenAIFile(BaseModel):
type: Literal["file"] = "file"
file: OpenAIFileFile
OpenAIChatCompletionContentPartParam = Annotated[
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
"""A message from the user in an OpenAI-compatible chat completion request."""
role: Literal["user"] = "user"
content: OpenAIChatCompletionMessageContent
name: str | None = None
@json_schema_type
class OpenAISystemMessageParam(BaseModel):
"""A system message providing instructions or context to the model."""
role: Literal["system"] = "system"
content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
"""Function call details for OpenAI-compatible tool calls."""
name: str | None = None
arguments: str | None = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
"""Tool call specification for OpenAI-compatible chat completion responses."""
index: int | None = None
id: str | None = None
type: Literal["function"] = "function"
function: OpenAIChatCompletionToolCallFunction | None = None
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."""
role: Literal["assistant"] = "assistant"
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
name: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@json_schema_type
class OpenAIToolMessageParam(BaseModel):
"""A message representing the result of a tool invocation in an OpenAI-compatible chat completion request."""
role: Literal["tool"] = "tool"
tool_call_id: str
content: OpenAIChatCompletionTextOnlyMessageContent
@json_schema_type
class OpenAIDeveloperMessageParam(BaseModel):
"""A message from the developer in an OpenAI-compatible chat completion request."""
role: Literal["developer"] = "developer"
content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
OpenAIMessageParam = Annotated[
OpenAIUserMessageParam
| OpenAISystemMessageParam
| OpenAIAssistantMessageParam
| OpenAIToolMessageParam
| OpenAIDeveloperMessageParam,
Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@json_schema_type
class OpenAIResponseFormatText(BaseModel):
"""Text response format for OpenAI-compatible chat completion requests."""
type: Literal["text"] = "text"
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
"""JSON schema specification for OpenAI-compatible structured response format."""
name: str
description: str | None
strict: bool | None
# A Pydantic BaseModel cannot define a field named "schema", since BaseModel already has one.
# And we don't want to use an alias here, because then we would have to handle that alias when
# converting to OpenAI params. So, to support "schema", we use a TypedDict.
schema: dict[str, Any] | None
@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
"""JSON schema response format for OpenAI-compatible chat completion requests."""
type: Literal["json_schema"] = "json_schema"
json_schema: OpenAIJSONSchema
@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
"""JSON object response format for OpenAI-compatible chat completion requests."""
type: Literal["json_object"] = "json_object"
OpenAIResponseFormatParam = Annotated[
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
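# Illustrative sketch, not part of this commit: building a structured-output response format
# from the models above. The schema name and body are made-up examples.
def _example_response_format() -> OpenAIResponseFormatParam:
    return OpenAIResponseFormatJSONSchema(
        json_schema=OpenAIJSONSchema(
            name="weather_report",
            strict=True,
            schema={"type": "object", "properties": {"temp_c": {"type": "number"}}},
        )
    )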
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.
:param token: The token
:param bytes: (Optional) The bytes for the token
:param logprob: The log probability of the token
"""
token: str
bytes: list[int] | None = None
logprob: float
@json_schema_type
class OpenAITokenLogProb(BaseModel):
"""The log probability for a token from an OpenAI-compatible chat completion response.
:param token: The token
:param bytes: (Optional) The bytes for the token
:param logprob: The log probability of the token
:param top_logprobs: The top log probabilities for the token
"""
token: str
bytes: list[int] | None = None
logprob: float
top_logprobs: list[OpenAITopLogProb]
@json_schema_type
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."""
content: list[OpenAITokenLogProb] | None = None
refusal: list[OpenAITokenLogProb] | None = None
@json_schema_type
class OpenAIChoiceDelta(BaseModel):
"""A delta from an OpenAI-compatible chat completion streaming response."""
content: str | None = None
refusal: str | None = None
role: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
reasoning_content: str | None = None
@json_schema_type
class OpenAIChunkChoice(BaseModel):
"""A chunk choice from an OpenAI-compatible chat completion streaming response."""
delta: OpenAIChoiceDelta
finish_reason: str
index: int
logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response."""
message: OpenAIMessageParam
finish_reason: str
index: int
logprobs: OpenAIChoiceLogprobs | None = None
class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI chat completion usage."""
reasoning_tokens: int | None = None
class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
"""Token details for prompt tokens in OpenAI chat completion usage."""
cached_tokens: int | None = None
@json_schema_type
class OpenAIChatCompletionUsage(BaseModel):
"""Usage information for OpenAI chat completion."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None
@json_schema_type
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request."""
id: str
choices: list[OpenAIChoice]
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
"""Chunk from a streaming response to an OpenAI-compatible chat completion request."""
id: str
choices: list[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
:param text_offset: (Optional) The offset of the token in the text
:param token_logprobs: (Optional) The log probabilities for the tokens
:param tokens: (Optional) The tokens
:param top_logprobs: (Optional) The top log probabilities for the tokens
"""
text_offset: list[int] | None = None
token_logprobs: list[float] | None = None
tokens: list[str] | None = None
top_logprobs: list[dict[str, float]] | None = None
@json_schema_type
class OpenAICompletionChoice(BaseModel):
"""A choice from an OpenAI-compatible completion response.
:param finish_reason: The reason the model stopped generating
:param text: The text of the choice
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the choice
"""
finish_reason: str
text: str
index: int
logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type
class OpenAICompletion(BaseModel):
"""Response from an OpenAI-compatible completion request.
:param id: The ID of the completion
:param choices: List of choices
:param created: The Unix timestamp in seconds when the completion was created
:param model: The model that was used to generate the completion
:param object: The object type, which will be "text_completion"
"""
id: str
choices: list[OpenAICompletionChoice]
created: int
model: str
object: Literal["text_completion"] = "text_completion"
@json_schema_type
class OpenAIEmbeddingData(BaseModel):
"""A single embedding data object from an OpenAI-compatible embeddings response."""
object: Literal["embedding"] = "embedding"
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
embedding: list[float] | str
index: int
@json_schema_type
class OpenAIEmbeddingUsage(BaseModel):
"""Usage information for an OpenAI-compatible embeddings response."""
prompt_tokens: int
total_tokens: int
@json_schema_type
class OpenAIEmbeddingsResponse(BaseModel):
"""Response from an OpenAI-compatible embeddings request."""
object: Literal["list"] = "list"
data: list[OpenAIEmbeddingData]
model: str
usage: OpenAIEmbeddingUsage
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left."""
none = "none"
start = "start"
end = "end"
class EmbeddingTaskType(Enum):
"""How is the embedding being used? This is only supported by asymmetric embedding models."""
query = "query"
document = "document"
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
"""Response from listing OpenAI-compatible chat completions."""
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible completion endpoint."""
# Standard OpenAI completion parameters
model: str
prompt: str | list[str] | list[int] | list[list[int]]
best_of: int | None = None
echo: bool | None = None
frequency_penalty: float | None = None
logit_bias: dict[str, float] | None = None
logprobs: int | None = Field(None, ge=0, le=5)
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
suffix: str | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible chat completion endpoint."""
# Standard OpenAI chat completion parameters
model: str
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
frequency_penalty: float | None = None
function_call: str | dict[str, Any] | None = None
functions: list[dict[str, Any]] | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_completion_tokens: int | None = None
max_tokens: int | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
presence_penalty: float | None = None
response_format: OpenAIResponseFormatParam | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
top_p: float | None = None
user: str | None = None
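# Illustrative sketch, not part of this commit: extra="allow" keeps provider-specific fields
# that are not declared above, and they can be read back via .model_extra. "guided_choice"
# is only an example of such a field, not a documented parameter.
def _example_extra_body() -> dict[str, Any]:
    req = OpenAIChatCompletionRequestWithExtraBody(
        model="example-model",
        messages=[OpenAIUserMessageParam(content="Reply with yes or no.")],
        guided_choice=["yes", "no"],
    )
    return req.model_extra or {}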
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible embeddings endpoint."""
model: str
input: str | list[str]
encoding_format: str | None = "float"
dimensions: int | None = None
user: str | None = None


@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Annotated
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from llama_stack.apis.common.responses import Order
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .inference_service import InferenceService
from .models import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
RerankResponse,
)
def get_inference_service(request: Request) -> InferenceService:
"""Dependency to get the inference service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.inference not in impls:
raise ValueError("Inference API implementation not found")
return impls[Api.inference]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Inference"],
responses=standard_responses,
)
router_v1alpha = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
tags=["Inference"],
responses=standard_responses,
)
@router_v1alpha.post(
"/inference/rerank",
response_model=RerankResponse,
summary="Rerank a list of documents.",
description="Rerank a list of documents based on their relevance to a query.",
)
async def rerank(
model: str = Body(..., description="The identifier of the reranking model to use."),
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam = Body(
..., description="The search query to rank items against."
),
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam] = Body(
..., description="List of items to rerank."
),
max_num_results: int | None = Body(None, description="Maximum number of results to return. Default: returns all."),
svc: InferenceService = Depends(get_inference_service),
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query."""
return await svc.rerank(model=model, query=query, items=items, max_num_results=max_num_results)
@router.post(
"/completions",
response_model=OpenAICompletion,
summary="Create completion.",
description="Create completion.",
)
async def openai_completion(
params: OpenAICompletionRequestWithExtraBody = Body(...),
svc: InferenceService = Depends(get_inference_service),
) -> OpenAICompletion:
"""Create completion."""
return await svc.openai_completion(params=params)
@router.post(
"/chat/completions",
summary="Create chat completions.",
description="Create chat completions.",
)
async def openai_chat_completion(
params: OpenAIChatCompletionRequestWithExtraBody = Body(...),
svc: InferenceService = Depends(get_inference_service),
):
"""Create chat completions."""
response = await svc.openai_chat_completion(params=params)
# Check if response is an async generator/iterator (streaming response)
# Check for __aiter__ method which all async iterators have
if hasattr(response, "__aiter__"):
# Convert async generator to SSE stream
async def sse_stream():
try:
async for chunk in response:
if isinstance(chunk, BaseModel):
data = chunk.model_dump_json()
else:
data = json.dumps(chunk)
yield f"data: {data}\n\n"
except Exception as e:
# Send error as SSE event
error_data = json.dumps({"error": {"message": str(e)}})
yield f"data: {error_data}\n\n"
return StreamingResponse(sse_stream(), media_type="text/event-stream")
return response
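# Illustrative sketch, not part of this commit: one way a client could consume the SSE stream
# produced above. httpx is an assumption (any SSE-capable HTTP client works), and the path
# below assumes LLAMA_STACK_API_V1 == "v1".
async def _example_stream_chat(base_url: str, payload: dict) -> list[dict]:
    import httpx

    chunks: list[dict] = []
    async with httpx.AsyncClient(base_url=base_url) as client:
        async with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    chunks.append(json.loads(line[len("data: ") :]))
    return chunks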
@router.post(
"/embeddings",
response_model=OpenAIEmbeddingsResponse,
summary="Create embeddings.",
description="Create embeddings.",
)
async def openai_embeddings(
params: OpenAIEmbeddingsRequestWithExtraBody = Body(...),
svc: InferenceService = Depends(get_inference_service),
) -> OpenAIEmbeddingsResponse:
"""Create embeddings."""
return await svc.openai_embeddings(params=params)
@router.get(
"/chat/completions",
response_model=ListOpenAIChatCompletionResponse,
summary="List chat completions.",
description="List chat completions.",
)
async def list_chat_completions(
after: str | None = Query(None, description="The ID of the last chat completion to return."),
limit: int | None = Query(20, description="The maximum number of chat completions to return."),
model: str | None = Query(None, description="The model to filter by."),
order: Order | None = Query(
Order.desc, description="The order to sort the chat completions by: 'asc' or 'desc'. Defaults to 'desc'."
),
svc: InferenceService = Depends(get_inference_service),
) -> ListOpenAIChatCompletionResponse:
"""List chat completions."""
return await svc.list_chat_completions(after=after, limit=limit, model=model, order=order)
@router.get(
"/chat/completions/{completion_id}",
response_model=OpenAICompletionWithInputMessages,
summary="Get chat completion.",
description="Get chat completion.",
)
async def get_chat_completion(
completion_id: Annotated[str, FastAPIPath(..., description="ID of the chat completion.")],
svc: InferenceService = Depends(get_inference_service),
) -> OpenAICompletionWithInputMessages:
"""Get chat completion."""
return await svc.get_chat_completion(completion_id=completion_id)
# For backward compatibility with the router registry system
def create_inference_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Inference API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1alpha)
return main_router
# Register the router factory
register_router(Api.inference, create_inference_router)
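# Illustrative sketch, not part of this commit: how this router is typically mounted. Any
# object implementing InferenceService can be used; get_inference_service() reads it back
# from app.state.impls. This assumes APIRouter is FastAPI-compatible.
def _example_app(inference_impl: InferenceService):
    from fastapi import FastAPI

    app = FastAPI()
    app.state.impls = {Api.inference: inference_impl}
    app.include_router(create_inference_router(impl_getter=None))  # impl_getter is unused here
    return app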


@ -4,4 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .inspect_service import InspectService
from .models import HealthInfo, ListRoutesResponse, RouteInfo, VersionInfo
# Backward compatibility - export Inspect as alias for InspectService
Inspect = InspectService
__all__ = ["Inspect", "InspectService", "ListRoutesResponse", "RouteInfo", "HealthInfo", "VersionInfo"]


@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import (
LLAMA_STACK_API_V1,
)
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
# Valid values for the route filter parameter.
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
@json_schema_type
class RouteInfo(BaseModel):
"""Information about an API route including its path, method, and implementing providers.
:param route: The API endpoint path
:param method: HTTP method for the route
:param provider_types: List of provider types that implement this route
"""
route: str
method: str
provider_types: list[str]
@json_schema_type
class HealthInfo(BaseModel):
"""Health status information for the service.
:param status: Current health status of the service
"""
status: HealthStatus
@json_schema_type
class VersionInfo(BaseModel):
"""Version information for the service.
:param version: Version number of the service
"""
version: str
class ListRoutesResponse(BaseModel):
"""Response containing a list of all available API routes.
:param data: List of available route information objects
"""
data: list[RouteInfo]
@runtime_checkable
class Inspect(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
"""List routes.
List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
:returns: Response containing information about all available routes.
"""
...
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def health(self) -> HealthInfo:
"""Get health status.
Get the current health status of the service.
:returns: Health information indicating if the service is operational.
"""
...
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def version(self) -> VersionInfo:
"""Get version.
Get the version of the service.
:returns: Version information containing the service version number.
"""
...


@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from .models import HealthInfo, ListRoutesResponse, VersionInfo
@runtime_checkable
class InspectService(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
async def list_routes(self) -> ListRoutesResponse:
"""List routes."""
...
async def health(self) -> HealthInfo:
"""Get health status."""
...
async def version(self) -> VersionInfo:
"""Get version."""
...


@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RouteInfo(BaseModel):
"""Information about an API route including its path, method, and implementing providers."""
route: str = Field(..., description="The API endpoint path")
method: str = Field(..., description="HTTP method for the route")
provider_types: list[str] = Field(..., description="List of provider types that implement this route")
@json_schema_type
class HealthInfo(BaseModel):
"""Health status information for the service."""
status: HealthStatus = Field(..., description="Current health status of the service")
@json_schema_type
class VersionInfo(BaseModel):
"""Version information for the service."""
version: str = Field(..., description="Version number of the service")
class ListRoutesResponse(BaseModel):
"""Response containing a list of all available API routes."""
data: list[RouteInfo] = Field(..., description="List of available route information objects")


@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from fastapi import Depends, Request
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .inspect_service import InspectService
from .models import HealthInfo, ListRoutesResponse, VersionInfo
def get_inspect_service(request: Request) -> InspectService:
"""Dependency to get the inspect service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.inspect not in impls:
raise ValueError("Inspect API implementation not found")
return impls[Api.inspect]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Inspect"],
responses=standard_responses,
)
@router.get(
"/inspect/routes",
response_model=ListRoutesResponse,
summary="List routes.",
description="List all available API routes with their methods and implementing providers.",
)
async def list_routes(svc: InspectService = Depends(get_inspect_service)) -> ListRoutesResponse:
"""List all available API routes."""
return await svc.list_routes()
@router.get(
"/health",
response_model=HealthInfo,
summary="Get health status.",
description="Get the current health status of the service.",
)
async def health(svc: InspectService = Depends(get_inspect_service)) -> HealthInfo:
"""Get the current health status of the service."""
return await svc.health()
@router.get(
"/version",
response_model=VersionInfo,
summary="Get version.",
description="Get the version of the service.",
)
async def version(svc: InspectService = Depends(get_inspect_service)) -> VersionInfo:
"""Get the version of the service."""
return await svc.version()
# For backward compatibility with the router registry system
def create_inspect_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Inspect API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.inspect, create_inspect_router)
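# Illustrative sketch, not part of this commit: exercising the health route with a stub
# service and FastAPI's TestClient. The stub values, and the assumption that HealthStatus
# exposes an OK member, are illustrative only.
def _example_health_check() -> dict:
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from llama_stack.providers.datatypes import HealthStatus

    class _StubInspect:
        async def list_routes(self) -> ListRoutesResponse:
            return ListRoutesResponse(data=[])

        async def health(self) -> HealthInfo:
            return HealthInfo(status=HealthStatus.OK)

        async def version(self) -> VersionInfo:
            return VersionInfo(version="0.0.0")

    app = FastAPI()
    app.state.impls = {Api.inspect: _StubInspect()}
    app.include_router(create_inspect_router(impl_getter=None))
    with TestClient(app) as client:
        return client.get(f"/{LLAMA_STACK_API_V1}/health").json()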


@ -4,4 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .model_schemas import (
ListModelsResponse,
Model,
ModelInput,
ModelType,
OpenAIListModelsResponse,
OpenAIModel,
RegisterModelRequest,
)
from .models_service import ModelService
# Backward compatibility - export Models as alias for ModelService
Models = ModelService
__all__ = [
"Models",
"ModelService",
"Model",
"ModelInput",
"ModelType",
"ListModelsResponse",
"RegisterModelRequest",
"OpenAIModel",
"OpenAIListModelsResponse",
]


@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type
class CommonModelFields(BaseModel):
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model.",
)
@json_schema_type
class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack."""
llm = "llm"
embedding = "embedding"
rerank = "rerank"
@json_schema_type
class Model(CommonModelFields, Resource):
"""A model resource representing an AI model registered in Llama Stack."""
type: Literal[ResourceType.model] = Field(
default=ResourceType.model, description="The resource type, always 'model' for model resources."
)
model_type: ModelType = Field(default=ModelType.llm, description="The type of model (LLM, embedding, or rerank model).")
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
assert self.provider_resource_id is not None, "Provider resource ID must be set"
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
@field_validator("provider_resource_id")
@classmethod
def validate_provider_resource_id(cls, v):
if v is None:
raise ValueError("provider_resource_id cannot be None")
return v
class ModelInput(CommonModelFields):
model_id: str
provider_id: str | None = None
provider_model_id: str | None = None
model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
"""Response model for listing models."""
data: list[Model] = Field(description="List of model resources.")
@json_schema_type
class RegisterModelRequest(BaseModel):
"""Request model for registering a new model."""
model_id: str = Field(..., description="The identifier of the model to register.")
provider_model_id: str | None = Field(default=None, description="The identifier of the model in the provider.")
provider_id: str | None = Field(default=None, description="The identifier of the provider.")
metadata: dict[str, Any] | None = Field(default=None, description="Any additional metadata for this model.")
model_type: ModelType | None = Field(default=None, description="The type of model to register.")
@json_schema_type
class OpenAIModel(BaseModel):
"""A model from OpenAI."""
id: str = Field(..., description="The ID of the model.")
object: Literal["model"] = Field(default="model", description="The object type, which will be 'model'.")
created: int = Field(..., description="The Unix timestamp in seconds when the model was created.")
owned_by: str = Field(..., description="The owner of the model.")
class OpenAIListModelsResponse(BaseModel):
"""Response model for listing OpenAI models."""
data: list[OpenAIModel] = Field(description="List of OpenAI model objects.")


@ -1,172 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel):
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
@json_schema_type
class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
"""
llm = "llm"
embedding = "embedding"
rerank = "rerank"
@json_schema_type
class Model(CommonModelFields, Resource):
"""A model resource representing an AI model registered in Llama Stack.
:param type: The resource type, always 'model' for model resources
:param model_type: The type of model (LLM or embedding model)
:param metadata: Any additional metadata for this model
:param identifier: Unique identifier for this resource in llama stack
:param provider_resource_id: Unique identifier for this resource in the provider
:param provider_id: ID of the provider that owns this resource
"""
type: Literal[ResourceType.model] = ResourceType.model
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
assert self.provider_resource_id is not None, "Provider resource ID must be set"
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
@field_validator("provider_resource_id")
@classmethod
def validate_provider_resource_id(cls, v):
if v is None:
raise ValueError("provider_resource_id cannot be None")
return v
class ModelInput(CommonModelFields):
model_id: str
provider_id: str | None = None
provider_model_id: str | None = None
model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: list[Model]
@json_schema_type
class OpenAIModel(BaseModel):
"""A model from OpenAI.
:param id: The ID of the model
:param object: The object type, which will be "model"
:param created: The Unix timestamp in seconds when the model was created
:param owned_by: The owner of the model
:param custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
custom_metadata: dict[str, Any] | None = None
class OpenAIListModelsResponse(BaseModel):
data: list[OpenAIModel]
@runtime_checkable
@trace_protocol
class Models(Protocol):
async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: A OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_model(
self,
model_id: str,
) -> Model:
"""Get model.
Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
"""Register model.
Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_model(
self,
model_id: str,
) -> None:
"""Unregister model.
Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...


@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .model_schemas import (
ListModelsResponse,
Model,
ModelType,
OpenAIListModelsResponse,
)
@runtime_checkable
@trace_protocol
class ModelService(Protocol):
async def list_models(self) -> ListModelsResponse:
"""List all models."""
...
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API."""
...
async def get_model(
self,
model_id: str,
) -> Model:
"""Get model."""
...
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
"""Register model."""
...
async def unregister_model(
self,
model_id: str,
) -> None:
"""Unregister model."""
...


@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .model_schemas import (
ListModelsResponse,
Model,
RegisterModelRequest,
)
from .models_service import ModelService
def get_model_service(request: Request) -> ModelService:
"""Dependency to get the model service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.models not in impls:
raise ValueError("Models API implementation not found")
return impls[Api.models]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Models"],
responses=standard_responses,
)
@router.get(
"/models",
response_model=ListModelsResponse,
summary="List all models.",
description="List all models registered in Llama Stack.",
)
async def list_models(svc: ModelService = Depends(get_model_service)) -> ListModelsResponse:
"""List all models."""
return await svc.list_models()
@router.get(
"/models/{model_id:path}",
response_model=Model,
summary="Get model.",
description="Get a model by its identifier.",
)
async def get_model(
model_id: Annotated[str, FastAPIPath(..., description="The identifier of the model to get.")],
svc: ModelService = Depends(get_model_service),
) -> Model:
"""Get model by its identifier."""
return await svc.get_model(model_id=model_id)
@router.post(
"/models",
response_model=Model,
summary="Register model.",
description="Register a new model in Llama Stack.",
)
async def register_model(
body: RegisterModelRequest = Body(...),
svc: ModelService = Depends(get_model_service),
) -> Model:
"""Register a new model."""
return await svc.register_model(
model_id=body.model_id,
provider_model_id=body.provider_model_id,
provider_id=body.provider_id,
metadata=body.metadata,
model_type=body.model_type,
)
@router.delete(
"/models/{model_id:path}",
response_model=None,
status_code=204,
summary="Unregister model.",
description="Unregister a model from Llama Stack.",
)
async def unregister_model(
model_id: Annotated[str, FastAPIPath(..., description="The identifier of the model to unregister.")],
svc: ModelService = Depends(get_model_service),
) -> None:
"""Unregister a model."""
await svc.unregister_model(model_id=model_id)
# For backward compatibility with the router registry system
def create_models_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Models API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.models, create_models_router)
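
A wiring sketch for the dependency above: get_model_service resolves the implementation from app.state.impls keyed by Api.models, so an application only needs to mount the router and populate that mapping. The routes module path and the minimal service below are assumptions for illustration only:

from fastapi import FastAPI

from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.model_schemas import ListModelsResponse
from llama_stack.apis.models.routes import router as models_router


class MyModelService:
    """Hypothetical minimal implementation; only list_models is shown."""

    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=[])


app = FastAPI()
app.include_router(models_router)
# get_model_service() looks the implementation up from app.state.impls by Api key.
app.state.impls = {Api.models: MyModelService()}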


@ -4,4 +4,61 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import *
# Import routes to trigger router registration
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from . import routes # noqa: F401
from .models import (
AlgorithmConfig,
DataConfig,
DatasetFormat,
DPOAlignmentConfig,
DPOLossType,
EfficiencyConfig,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
OptimizerConfig,
OptimizerType,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobLogStream,
PostTrainingJobStatusResponse,
PostTrainingRLHFRequest,
PreferenceOptimizeRequest,
QATFinetuningConfig,
RLHFAlgorithm,
SupervisedFineTuneRequest,
TrainingConfig,
)
from .post_training_service import PostTrainingService
# Backward compatibility - export PostTraining as alias for PostTrainingService
PostTraining = PostTrainingService
__all__ = [
"PostTraining",
"PostTrainingService",
"Checkpoint",
"JobStatus",
"OptimizerType",
"DatasetFormat",
"DataConfig",
"OptimizerConfig",
"EfficiencyConfig",
"TrainingConfig",
"LoraFinetuningConfig",
"QATFinetuningConfig",
"AlgorithmConfig",
"PostTrainingJobLogStream",
"RLHFAlgorithm",
"DPOLossType",
"DPOAlignmentConfig",
"PostTrainingRLHFRequest",
"PostTrainingJob",
"PostTrainingJobStatusResponse",
"ListPostTrainingJobsResponse",
"PostTrainingJobArtifactsResponse",
"SupervisedFineTuneRequest",
"PreferenceOptimizeRequest",
]


@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class OptimizerType(Enum):
"""Available optimizer algorithms for training."""
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DatasetFormat(Enum):
"""Format of the training dataset."""
instruct = "instruct"
dialog = "dialog"
@json_schema_type
class DataConfig(BaseModel):
"""Configuration for training data and data loading."""
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: str | None = None
packed: bool | None = False
train_on_input: bool | None = False
@json_schema_type
class OptimizerConfig(BaseModel):
"""Configuration parameters for the optimization algorithm."""
optimizer_type: OptimizerType
lr: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
"""Configuration for memory and compute efficiency optimizations."""
enable_activation_checkpointing: bool | None = False
enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: bool | None = False
@json_schema_type
class TrainingConfig(BaseModel):
"""Comprehensive configuration for the training process."""
n_epochs: int
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: int | None = 1
data_config: DataConfig | None = None
optimizer_config: OptimizerConfig | None = None
efficiency_config: EfficiencyConfig | None = None
dtype: str | None = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning."""
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: list[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool | None = False
quantize_base: bool | None = False
@json_schema_type
class QATFinetuningConfig(BaseModel):
"""Configuration for Quantization-Aware Training (QAT) fine-tuning."""
type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: list[str]
@json_schema_type
class RLHFAlgorithm(Enum):
"""Available reinforcement learning from human feedback algorithms."""
dpo = "dpo"
@json_schema_type
class DPOLossType(Enum):
sigmoid = "sigmoid"
hinge = "hinge"
ipo = "ipo"
kto_pair = "kto_pair"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
"""Configuration for Direct Preference Optimization (DPO) alignment."""
beta: float
loss_type: DPOLossType = DPOLossType.sigmoid
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model using reinforcement learning from human feedback."""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: dict[str, Any]
logger_config: dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str = Field(..., description="The UUID of the job")
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: JobStatus
scheduled_at: datetime | None = None
started_at: datetime | None = None
completed_at: datetime | None = None
resources_allocated: dict[str, Any] | None = None
checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: list[PostTrainingJob] = Field(..., description="The list of training jobs")
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str = Field(..., description="The UUID of the job")
checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
@json_schema_type
class SupervisedFineTuneRequest(BaseModel):
"""Request to run supervised fine-tuning of a model."""
job_uuid: str = Field(..., description="The UUID of the job to create")
training_config: TrainingConfig = Field(..., description="The training configuration")
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
logger_config: dict[str, Any] = Field(..., description="The logger configuration")
model: str | None = Field(
default=None,
description="Model descriptor for training if not in provider config`",
)
checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to")
algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration")
@json_schema_type
class PreferenceOptimizeRequest(BaseModel):
"""Request to run preference optimization of a model."""
job_uuid: str = Field(..., description="The UUID of the job to create")
finetuned_model: str = Field(..., description="The model to fine-tune")
algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration")
training_config: TrainingConfig = Field(..., description="The training configuration")
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
logger_config: dict[str, Any] = Field(..., description="The logger configuration")
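
A composition sketch using these request schemas; the module path is an assumption and the dataset/model identifiers are placeholders. Because AlgorithmConfig is a union discriminated on `type`, the LoRA config below serializes with "type": "LoRA":

from llama_stack.apis.post_training.models import (
    DataConfig, DatasetFormat, LoraFinetuningConfig, OptimizerConfig,
    OptimizerType, SupervisedFineTuneRequest, TrainingConfig,
)

request = SupervisedFineTuneRequest(
    job_uuid="job-123",
    training_config=TrainingConfig(
        n_epochs=1,
        data_config=DataConfig(
            dataset_id="alpaca", batch_size=8, shuffle=True, data_format=DatasetFormat.instruct
        ),
        optimizer_config=OptimizerConfig(
            optimizer_type=OptimizerType.adamw, lr=1e-4, weight_decay=0.01, num_warmup_steps=100
        ),
    ),
    hyperparam_search_config={},
    logger_config={},
    model="llama-3.1-8b-instruct",
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj"],
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
)
print(request.model_dump_json(indent=2))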


@ -1,368 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class OptimizerType(Enum):
"""Available optimizer algorithms for training.
:cvar adam: Adaptive Moment Estimation optimizer
:cvar adamw: AdamW optimizer with weight decay
:cvar sgd: Stochastic Gradient Descent optimizer
"""
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DatasetFormat(Enum):
"""Format of the training dataset.
:cvar instruct: Instruction-following format with prompt and completion
:cvar dialog: Multi-turn conversation format with messages
"""
instruct = "instruct"
dialog = "dialog"
@json_schema_type
class DataConfig(BaseModel):
"""Configuration for training data and data loading.
:param dataset_id: Unique identifier for the training dataset
:param batch_size: Number of samples per training batch
:param shuffle: Whether to shuffle the dataset during training
:param data_format: Format of the dataset (instruct or dialog)
:param validation_dataset_id: (Optional) Unique identifier for the validation dataset
:param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
:param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
"""
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: str | None = None
packed: bool | None = False
train_on_input: bool | None = False
@json_schema_type
class OptimizerConfig(BaseModel):
"""Configuration parameters for the optimization algorithm.
:param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
:param lr: Learning rate for the optimizer
:param weight_decay: Weight decay coefficient for regularization
:param num_warmup_steps: Number of steps for learning rate warmup
"""
optimizer_type: OptimizerType
lr: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
"""Configuration for memory and compute efficiency optimizations.
:param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
:param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
:param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
:param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
"""
enable_activation_checkpointing: bool | None = False
enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: bool | None = False
@json_schema_type
class TrainingConfig(BaseModel):
"""Comprehensive configuration for the training process.
:param n_epochs: Number of training epochs to run
:param max_steps_per_epoch: Maximum number of steps to run per epoch
:param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
:param max_validation_steps: (Optional) Maximum number of validation steps per epoch
:param data_config: (Optional) Configuration for data loading and formatting
:param optimizer_config: (Optional) Configuration for the optimization algorithm
:param efficiency_config: (Optional) Configuration for memory and compute optimizations
:param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
"""
n_epochs: int
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: int | None = 1
data_config: DataConfig | None = None
optimizer_config: OptimizerConfig | None = None
efficiency_config: EfficiencyConfig | None = None
dtype: str | None = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning.
:param type: Algorithm type identifier, always "LoRA"
:param lora_attn_modules: List of attention module names to apply LoRA to
:param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
:param apply_lora_to_output: Whether to apply LoRA to output projection layers
:param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
:param alpha: LoRA scaling parameter that controls adaptation strength
:param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
:param quantize_base: (Optional) Whether to quantize the base model weights
"""
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: list[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool | None = False
quantize_base: bool | None = False
@json_schema_type
class QATFinetuningConfig(BaseModel):
"""Configuration for Quantization-Aware Training (QAT) fine-tuning.
:param type: Algorithm type identifier, always "QAT"
:param quantizer_name: Name of the quantization algorithm to use
:param group_size: Size of groups for grouped quantization
"""
type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job.
:param job_uuid: Unique identifier for the training job
:param log_lines: List of log message strings from the training process
"""
job_uuid: str
log_lines: list[str]
@json_schema_type
class RLHFAlgorithm(Enum):
"""Available reinforcement learning from human feedback algorithms.
:cvar dpo: Direct Preference Optimization algorithm
"""
dpo = "dpo"
@json_schema_type
class DPOLossType(Enum):
sigmoid = "sigmoid"
hinge = "hinge"
ipo = "ipo"
kto_pair = "kto_pair"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
"""Configuration for Direct Preference Optimization (DPO) alignment.
:param beta: Temperature parameter for the DPO loss
:param loss_type: The type of loss function to use for DPO
"""
beta: float
loss_type: DPOLossType = DPOLossType.sigmoid
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model using reinforcement learning from human feedback.
:param job_uuid: Unique identifier for the training job
:param finetuned_model: URL or path to the base model to fine-tune
:param dataset_id: Unique identifier for the training dataset
:param validation_dataset_id: Unique identifier for the validation dataset
:param algorithm: RLHF algorithm to use for training
:param algorithm_config: Configuration parameters for the RLHF algorithm
:param optimizer_config: Configuration parameters for the optimization algorithm
:param training_config: Configuration parameters for the training process
:param hyperparam_search_config: Configuration for hyperparameter search
:param logger_config: Configuration for training logging
"""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: dict[str, Any]
logger_config: dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job.
:param job_uuid: Unique identifier for the training job
:param status: Current status of the training job
:param scheduled_at: (Optional) Timestamp when the job was scheduled
:param started_at: (Optional) Timestamp when the job execution began
:param completed_at: (Optional) Timestamp when the job finished, if completed
:param resources_allocated: (Optional) Information about computational resources allocated to the job
:param checkpoints: List of model checkpoints created during training
"""
job_uuid: str
status: JobStatus
scheduled_at: datetime | None = None
started_at: datetime | None = None
completed_at: datetime | None = None
resources_allocated: dict[str, Any] | None = None
checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: list[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job.
:param job_uuid: Unique identifier for the training job
:param checkpoints: List of model checkpoints created during training
"""
job_uuid: str
checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str | None = Field(
default=None,
description="Model descriptor for training if not in provider config`",
),
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...


@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import (
AlgorithmConfig,
DPOAlignmentConfig,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
@runtime_checkable
@trace_protocol
class PostTrainingService(Protocol):
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str | None = None,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model."""
...
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
"""Run preference optimization of a model."""
...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs."""
...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job."""
...
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job."""
...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job."""
...


@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from fastapi import Body, Depends, Query, Request
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import (
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
PreferenceOptimizeRequest,
SupervisedFineTuneRequest,
)
from .post_training_service import PostTrainingService
def get_post_training_service(request: Request) -> PostTrainingService:
"""Dependency to get the post training service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.post_training not in impls:
raise ValueError("Post Training API implementation not found")
return impls[Api.post_training]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Post Training"],
responses=standard_responses,
)
router_v1alpha = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
tags=["Post Training"],
responses=standard_responses,
)
@router.post(
"/post-training/supervised-fine-tune",
response_model=PostTrainingJob,
summary="Run supervised fine-tuning of a model",
description="Run supervised fine-tuning of a model",
deprecated=True,
)
@router_v1alpha.post(
"/post-training/supervised-fine-tune",
response_model=PostTrainingJob,
summary="Run supervised fine-tuning of a model",
description="Run supervised fine-tuning of a model",
)
async def supervised_fine_tune(
body: SupervisedFineTuneRequest = Body(...),
svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model."""
return await svc.supervised_fine_tune(
job_uuid=body.job_uuid,
training_config=body.training_config,
hyperparam_search_config=body.hyperparam_search_config,
logger_config=body.logger_config,
model=body.model,
checkpoint_dir=body.checkpoint_dir,
algorithm_config=body.algorithm_config,
)
@router.post(
"/post-training/preference-optimize",
response_model=PostTrainingJob,
summary="Run preference optimization of a model",
description="Run preference optimization of a model",
deprecated=True,
)
@router_v1alpha.post(
"/post-training/preference-optimize",
response_model=PostTrainingJob,
summary="Run preference optimization of a model",
description="Run preference optimization of a model",
)
async def preference_optimize(
body: PreferenceOptimizeRequest = Body(...),
svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJob:
"""Run preference optimization of a model."""
return await svc.preference_optimize(
job_uuid=body.job_uuid,
finetuned_model=body.finetuned_model,
algorithm_config=body.algorithm_config,
training_config=body.training_config,
hyperparam_search_config=body.hyperparam_search_config,
logger_config=body.logger_config,
)
@router.get(
"/post-training/jobs",
response_model=ListPostTrainingJobsResponse,
summary="Get all training jobs",
description="Get all training jobs",
deprecated=True,
)
@router_v1alpha.get(
"/post-training/jobs",
response_model=ListPostTrainingJobsResponse,
summary="Get all training jobs",
description="Get all training jobs",
)
async def get_training_jobs(
svc: PostTrainingService = Depends(get_post_training_service),
) -> ListPostTrainingJobsResponse:
"""Get all training jobs."""
return await svc.get_training_jobs()
@router.get(
"/post-training/job/status",
response_model=PostTrainingJobStatusResponse,
summary="Get the status of a training job",
description="Get the status of a training job",
deprecated=True,
)
@router_v1alpha.get(
"/post-training/job/status",
response_model=PostTrainingJobStatusResponse,
summary="Get the status of a training job",
description="Get the status of a training job",
)
async def get_training_job_status(
job_uuid: str = Query(..., description="The UUID of the job to get the status of"),
svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJobStatusResponse:
"""Get the status of a training job."""
return await svc.get_training_job_status(job_uuid=job_uuid)
@router.post(
"/post-training/job/cancel",
response_model=None,
status_code=204,
summary="Cancel a training job",
description="Cancel a training job",
deprecated=True,
)
@router_v1alpha.post(
"/post-training/job/cancel",
response_model=None,
status_code=204,
summary="Cancel a training job",
description="Cancel a training job",
)
async def cancel_training_job(
job_uuid: str = Query(..., description="The UUID of the job to cancel"),
svc: PostTrainingService = Depends(get_post_training_service),
) -> None:
"""Cancel a training job."""
await svc.cancel_training_job(job_uuid=job_uuid)
@router.get(
"/post-training/job/artifacts",
response_model=PostTrainingJobArtifactsResponse,
summary="Get the artifacts of a training job",
description="Get the artifacts of a training job",
deprecated=True,
)
@router_v1alpha.get(
"/post-training/job/artifacts",
response_model=PostTrainingJobArtifactsResponse,
summary="Get the artifacts of a training job",
description="Get the artifacts of a training job",
)
async def get_training_job_artifacts(
job_uuid: str = Query(..., description="The UUID of the job to get the artifacts of"),
svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job."""
return await svc.get_training_job_artifacts(job_uuid=job_uuid)
# For backward compatibility with the router registry system
def create_post_training_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Post Training API (legacy compatibility)."""
main_router = APIRouter()
main_router.include_router(router)
main_router.include_router(router_v1alpha)
return main_router
# Register the router factory
register_router(Api.post_training, create_post_training_router)
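
A client-side sketch against the v1alpha route registered above; the host, port, and the literal "v1alpha" prefix are assumptions about the deployment:

import httpx

payload = {
    "job_uuid": "job-123",
    "training_config": {"n_epochs": 1},
    "hyperparam_search_config": {},
    "logger_config": {},
    "model": "llama-3.1-8b-instruct",
}
resp = httpx.post("http://localhost:8321/v1alpha/post-training/supervised-fine-tune", json=payload)
resp.raise_for_status()
print(resp.json())  # expected to echo the created PostTrainingJob, e.g. {"job_uuid": "job-123"}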


@ -4,6 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts
# Import routes to trigger router registration
from . import routes # noqa: F401
from .models import (
CreatePromptRequest,
ListPromptsResponse,
Prompt,
SetDefaultVersionRequest,
UpdatePromptRequest,
)
from .prompts_service import PromptService
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"] # Backward compatibility - export Prompts as alias for PromptService
Prompts = PromptService
__all__ = [
"Prompts",
"PromptService",
"Prompt",
"ListPromptsResponse",
"CreatePromptRequest",
"UpdatePromptRequest",
"SetDefaultVersionRequest",
]


@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import secrets
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."""
prompt: str | None = Field(
default=None,
description="The system prompt text with variable placeholders. Variables are only supported when using the Responses API.",
)
version: int = Field(description="Version (integer starting at 1, incremented on save).", ge=1)
prompt_id: str = Field(description="Unique identifier formatted as 'pmpt_<48-digit-hash>'.")
variables: list[str] = Field(
default_factory=list, description="List of prompt variable names that can be used in the prompt template."
)
is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version for this prompt."
)
@field_validator("prompt_id")
@classmethod
def validate_prompt_id(cls, prompt_id: str) -> str:
if not isinstance(prompt_id, str):
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
if not prompt_id.startswith("pmpt_"):
raise ValueError("prompt_id must start with 'pmpt_' prefix")
hex_part = prompt_id[5:]
if len(hex_part) != 48:
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
for char in hex_part:
if char not in "0123456789abcdef":
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
return prompt_id
@field_validator("version")
@classmethod
def validate_version(cls, prompt_version: int) -> int:
if prompt_version < 1:
raise ValueError("version must be >= 1")
return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod
def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes)
random_bytes = secrets.token_bytes(24)
hex_string = random_bytes.hex()
return f"pmpt_{hex_string}"
class ListPromptsResponse(BaseModel):
"""Response model to list prompts."""
data: list[Prompt] = Field(description="List of prompt resources.")
@json_schema_type
class CreatePromptRequest(BaseModel):
"""Request model for creating a new prompt."""
prompt: str = Field(..., description="The prompt text content with variable placeholders.")
variables: list[str] | None = Field(
default=None, description="List of variable names that can be used in the prompt template."
)
@json_schema_type
class UpdatePromptRequest(BaseModel):
"""Request model for updating an existing prompt."""
prompt: str = Field(..., description="The updated prompt text content.")
version: int = Field(..., description="The current version of the prompt being updated.")
variables: list[str] | None = Field(
default=None, description="Updated list of variable names that can be used in the prompt template."
)
set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")
@json_schema_type
class SetDefaultVersionRequest(BaseModel):
"""Request model for setting a prompt version as default."""
version: int = Field(..., description="The version to set as default.")
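
A sketch of the validators in action (module path assumed): the model_validator rejects template variables that are not declared, and generate_prompt_id produces the pmpt_<48-hex> form the field validator expects:

from llama_stack.apis.prompts.models import Prompt

ok = Prompt(
    prompt="Hello {{ name }}, welcome to {{ product }}!",
    version=1,
    prompt_id=Prompt.generate_prompt_id(),
    variables=["name", "product"],
)
print(ok.prompt_id)  # e.g. pmpt_3f2a... (48 lowercase hex chars)

try:
    Prompt(prompt="Hello {{ name }}!", version=1, prompt_id=Prompt.generate_prompt_id(), variables=[])
except ValueError as exc:  # pydantic wraps the validator error in a ValidationError
    print(exc)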


@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import secrets
from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
:param version: Version (integer starting at 1, incremented on save)
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
:param variables: List of prompt variable names that can be used in the prompt template
:param is_default: Boolean indicating whether this version is the default version for this prompt
"""
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template"
)
is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version"
)
@field_validator("prompt_id")
@classmethod
def validate_prompt_id(cls, prompt_id: str) -> str:
if not isinstance(prompt_id, str):
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
if not prompt_id.startswith("pmpt_"):
raise ValueError("prompt_id must start with 'pmpt_' prefix")
hex_part = prompt_id[5:]
if len(hex_part) != 48:
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
for char in hex_part:
if char not in "0123456789abcdef":
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
return prompt_id
@field_validator("version")
@classmethod
def validate_version(cls, prompt_version: int) -> int:
if prompt_version < 1:
raise ValueError("version must be >= 1")
return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod
def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes)
random_bytes = secrets.token_bytes(24)
hex_string = random_bytes.hex()
return f"pmpt_{hex_string}"
class ListPromptsResponse(BaseModel):
"""Response model to list prompts."""
data: list[Prompt]
@runtime_checkable
@trace_protocol
class Prompts(Protocol):
"""Prompts
Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts.
:returns: A ListPromptsResponse containing all prompts.
"""
...
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompt_versions(
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List prompt versions.
List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_prompt(
self,
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get prompt.
Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest).
:returns: A Prompt resource.
"""
...
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create prompt.
Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template.
:returns: The created Prompt resource.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update prompt.
Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
:param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template.
:param set_as_default: Set the new version as the default (default=True).
:returns: The updated Prompt resource with incremented version.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_prompt(
self,
prompt_id: str,
) -> None:
"""Delete prompt.
Delete a prompt.
:param prompt_id: The identifier of the prompt to delete.
"""
...
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
async def set_default_version(
self,
prompt_id: str,
version: int,
) -> Prompt:
"""Set prompt version.
Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt.
:param version: The version to set as default.
:returns: The prompt with the specified version now set as default.
"""
...


@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import ListPromptsResponse, Prompt
@runtime_checkable
@trace_protocol
class PromptService(Protocol):
"""Prompts
Protocol for prompt management operations."""
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts."""
...
async def list_prompt_versions(
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List prompt versions."""
...
async def get_prompt(
self,
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get prompt."""
...
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create prompt."""
...
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update prompt."""
...
async def delete_prompt(
self,
prompt_id: str,
) -> None:
"""Delete prompt."""
...
async def set_default_version(
self,
prompt_id: str,
version: int,
) -> Prompt:
"""Set prompt version."""
...


@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import (
CreatePromptRequest,
ListPromptsResponse,
Prompt,
SetDefaultVersionRequest,
UpdatePromptRequest,
)
from .prompts_service import PromptService
def get_prompt_service(request: Request) -> PromptService:
"""Dependency to get the prompt service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.prompts not in impls:
raise ValueError("Prompts API implementation not found")
return impls[Api.prompts]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Prompts"],
responses=standard_responses,
)
@router.get(
"/prompts",
response_model=ListPromptsResponse,
summary="List all prompts",
description="List all prompts registered in Llama Stack",
)
async def list_prompts(svc: PromptService = Depends(get_prompt_service)) -> ListPromptsResponse:
"""List all prompts."""
return await svc.list_prompts()
@router.get(
"/prompts/{prompt_id}/versions",
response_model=ListPromptsResponse,
summary="List prompt versions",
description="List all versions of a specific prompt",
)
async def list_prompt_versions(
prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to list versions for")],
svc: PromptService = Depends(get_prompt_service),
) -> ListPromptsResponse:
"""List prompt versions."""
return await svc.list_prompt_versions(prompt_id=prompt_id)
@router.get(
"/prompts/{prompt_id}",
response_model=Prompt,
summary="Get prompt",
description="Get a prompt by its identifier and optional version",
)
async def get_prompt(
prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to get")],
version: int | None = Query(None, description="The version of the prompt to get (defaults to latest)"),
svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
"""Get prompt by its identifier and optional version."""
return await svc.get_prompt(prompt_id=prompt_id, version=version)
@router.post(
"/prompts",
response_model=Prompt,
summary="Create prompt",
description="Create a new prompt",
)
async def create_prompt(
body: CreatePromptRequest = Body(...),
svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
"""Create a new prompt."""
return await svc.create_prompt(prompt=body.prompt, variables=body.variables)
@router.post(
"/prompts/{prompt_id}",
response_model=Prompt,
summary="Update prompt",
description="Update an existing prompt (increments version)",
)
async def update_prompt(
prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to update")],
body: UpdatePromptRequest = Body(...),
svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
"""Update an existing prompt."""
return await svc.update_prompt(
prompt_id=prompt_id,
prompt=body.prompt,
version=body.version,
variables=body.variables,
set_as_default=body.set_as_default,
)
@router.delete(
"/prompts/{prompt_id}",
response_model=None,
status_code=204,
summary="Delete prompt",
description="Delete a prompt",
)
async def delete_prompt(
prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to delete")],
svc: PromptService = Depends(get_prompt_service),
) -> None:
"""Delete a prompt."""
await svc.delete_prompt(prompt_id=prompt_id)
@router.post(
"/prompts/{prompt_id}/set-default-version",
response_model=Prompt,
summary="Set prompt version",
description="Set which version of a prompt should be the default in get_prompt (latest)",
)
async def set_default_version(
prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt")],
body: SetDefaultVersionRequest = Body(...),
svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
"""Set which version of a prompt should be the default."""
return await svc.set_default_version(prompt_id=prompt_id, version=body.version)
# For backward compatibility with the router registry system
def create_prompts_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Prompts API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.prompts, create_prompts_router)
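
A client-side sketch of the create/update flow above; the base URL and "v1" prefix are assumptions, and note that update_prompt increments the version on the server:

import httpx

base = "http://localhost:8321/v1"
created = httpx.post(f"{base}/prompts", json={"prompt": "Hi {{ name }}", "variables": ["name"]}).json()
updated = httpx.post(
    f"{base}/prompts/{created['prompt_id']}",
    json={"prompt": "Hello {{ name }}!", "version": created["version"], "variables": ["name"]},
).json()
print(created["version"], "->", updated["version"])  # version is incremented by the service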


@ -4,4 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .models import ListProvidersResponse, ProviderInfo
from .providers_service import ProviderService
# Backward compatibility - export Providers as alias for ProviderService
Providers = ProviderService
__all__ = ["Providers", "ProviderService", "ListProvidersResponse", "ProviderInfo"]


@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ProviderInfo(BaseModel):
"""Information about a registered provider including its configuration and health status."""
api: str = Field(..., description="The API name this provider implements")
provider_id: str = Field(..., description="Unique identifier for the provider")
provider_type: str = Field(..., description="The type of provider implementation")
config: dict[str, Any] = Field(..., description="Configuration parameters for the provider")
health: HealthResponse = Field(..., description="Current health status of the provider")
class ListProvidersResponse(BaseModel):
"""Response containing a list of all available providers."""
data: list[ProviderInfo] = Field(..., description="List of provider information objects")


@ -1,69 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
"""Information about a registered provider including its configuration and health status.
:param api: The API name this provider implements
:param provider_id: Unique identifier for the provider
:param provider_type: The type of provider implementation
:param config: Configuration parameters for the provider
:param health: Current health status of the provider
"""
api: str
provider_id: str
provider_type: str
config: dict[str, Any]
health: HealthResponse
class ListProvidersResponse(BaseModel):
"""Response containing a list of all available providers.
:param data: List of provider information objects
"""
data: list[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""Providers
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse:
"""List providers.
List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get provider.
Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from .models import ListProvidersResponse, ProviderInfo
@runtime_checkable
class ProviderService(Protocol):
"""Providers
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
async def list_providers(self) -> ListProvidersResponse:
"""List providers."""
...
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get provider."""
...


@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import ListProvidersResponse, ProviderInfo
from .providers_service import ProviderService
def get_provider_service(request: Request) -> ProviderService:
"""Dependency to get the provider service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.providers not in impls:
raise ValueError("Providers API implementation not found")
return impls[Api.providers]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Providers"],
responses=standard_responses,
)
@router.get(
"/providers",
response_model=ListProvidersResponse,
summary="List providers",
description="List all available providers",
)
async def list_providers(svc: ProviderService = Depends(get_provider_service)) -> ListProvidersResponse:
"""List all available providers."""
return await svc.list_providers()
@router.get(
"/providers/{provider_id}",
response_model=ProviderInfo,
summary="Get provider",
description="Get detailed information about a specific provider",
)
async def inspect_provider(
provider_id: Annotated[str, FastAPIPath(..., description="The ID of the provider to inspect")],
svc: ProviderService = Depends(get_provider_service),
) -> ProviderInfo:
"""Get detailed information about a specific provider."""
return await svc.inspect_provider(provider_id=provider_id)
# For backward compatibility with the router registry system
def create_providers_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Providers API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.providers, create_providers_router)
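
An end-to-end sketch exercising these routes with FastAPI's TestClient and a hypothetical stub service; LLAMA_STACK_API_V1 is assumed to render the "/v1" prefix, and the routes module path is likewise assumed:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack.apis.datatypes import Api
from llama_stack.apis.providers.routes import router as providers_router


class StubProviderService:
    """Hypothetical stand-in returning an empty provider list."""

    async def list_providers(self):
        return {"data": []}  # validated against ListProvidersResponse by FastAPI


app = FastAPI()
app.include_router(providers_router)
app.state.impls = {Api.providers: StubProviderService()}

client = TestClient(app)
print(client.get("/v1/providers").json())  # {"data": []}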


@ -4,4 +4,31 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .models import (
ModerationObject,
ModerationObjectResults,
RunModerationRequest,
RunShieldRequest,
RunShieldResponse,
SafetyViolation,
ViolationLevel,
)
from .safety_service import SafetyService, ShieldStore
# Backward compatibility - export Safety as alias for SafetyService
Safety = SafetyService
__all__ = [
"Safety",
"SafetyService",
"ShieldStore",
"ModerationObject",
"ModerationObjectResults",
"RunShieldRequest",
"RunShieldResponse",
"RunModerationRequest",
"SafetyViolation",
"ViolationLevel",
]


@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object."""
flagged: bool = Field(..., description="Whether any of the below categories are flagged.")
categories: dict[str, bool] | None = Field(
default=None, description="A list of the categories, and whether they are flagged or not."
)
category_applied_input_types: dict[str, list[str]] | None = Field(
default=None,
description="A list of the categories along with the input type(s) that the score applies to.",
)
category_scores: dict[str, float] | None = Field(
default=None, description="A list of the categories along with their scores as predicted by model."
)
user_message: str | None = Field(default=None, description="User message.")
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata.")
@json_schema_type
class ModerationObject(BaseModel):
"""A moderation object."""
id: str = Field(..., description="The unique identifier for the moderation request.")
model: str = Field(..., description="The model used to generate the moderation results.")
results: list[ModerationObjectResults] = Field(..., description="A list of moderation objects.")
@json_schema_type
class ViolationLevel(Enum):
"""Severity level of a safety violation.
:cvar INFO: Informational level violation that does not require action
:cvar WARN: Warning level violation that suggests caution but allows continuation
:cvar ERROR: Error level violation that requires blocking or intervention
"""
INFO = "info"
WARN = "warn"
ERROR = "error"
@json_schema_type
class SafetyViolation(BaseModel):
"""Details of a safety violation detected by content moderation."""
violation_level: ViolationLevel = Field(..., description="Severity level of the violation.")
user_message: str | None = Field(default=None, description="Message to convey to the user about the violation.")
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Additional metadata including specific violation codes for debugging and telemetry.",
)
@json_schema_type
class RunShieldResponse(BaseModel):
"""Response from running a safety shield."""
violation: SafetyViolation | None = Field(
default=None, description="Safety violation detected by the shield, if any."
)
@json_schema_type
class RunShieldRequest(BaseModel):
"""Request model for running a shield."""
shield_id: str = Field(..., description="The identifier of the shield to run.")
messages: list[OpenAIMessageParam] = Field(..., description="The messages to run the shield on.")
params: dict[str, Any] = Field(..., description="The parameters of the shield.")
@json_schema_type
class RunModerationRequest(BaseModel):
"""Request model for running moderation."""
input: str | list[str] = Field(
...,
description="Input (or inputs) to classify. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.",
)
model: str | None = Field(default=None, description="The content moderation model you would like to use.")
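
A construction sketch for these schemas (module path assumed): a shield response with no violation, and a single moderation result:

from llama_stack.apis.safety.models import ModerationObject, ModerationObjectResults, RunShieldResponse

clean = RunShieldResponse()  # violation defaults to None, i.e. the content passed the shield
moderation = ModerationObject(
    id="modr-001",
    model="llama-guard-3-8b",
    results=[ModerationObjectResults(flagged=False, categories={"violence": False})],
)
print(clean.violation is None)          # True
print(moderation.results[0].flagged)    # False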


@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from fastapi import Body, Depends, Request
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import ModerationObject, RunModerationRequest, RunShieldRequest, RunShieldResponse
from .safety_service import SafetyService
def get_safety_service(request: Request) -> SafetyService:
"""Dependency to get the safety service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.safety not in impls:
raise ValueError("Safety API implementation not found")
return impls[Api.safety]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Safety"],
responses=standard_responses,
)
@router.post(
"/safety/run-shield",
response_model=RunShieldResponse,
summary="Run shield.",
description="Run a shield.",
)
async def run_shield(
body: RunShieldRequest = Body(...),
svc: SafetyService = Depends(get_safety_service),
) -> RunShieldResponse:
"""Run a shield."""
return await svc.run_shield(shield_id=body.shield_id, messages=body.messages, params=body.params)
@router.post(
"/moderations",
response_model=ModerationObject,
summary="Create moderation.",
description="Classifies if text and/or image inputs are potentially harmful.",
)
async def run_moderation(
body: RunModerationRequest = Body(...),
svc: SafetyService = Depends(get_safety_service),
) -> ModerationObject:
"""Create moderation."""
return await svc.run_moderation(input=body.input, model=body.model)
# For backward compatibility with the router registry system
def create_safety_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Safety API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.safety, create_safety_router)
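The handlers above resolve the provider implementation from `request.app.state.impls`, so the router can be exercised in isolation with a stub service. A rough sketch follows, under several assumptions: the module path `llama_stack.apis.safety.routes`, that `LLAMA_STACK_API_V1` renders as "v1", and that this `APIRouter` mounts via FastAPI's `include_router`.
# Sketch only; see the assumptions above. The stub is hypothetical, not a real provider.
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.apis.datatypes import Api
from llama_stack.apis.safety import ModerationObject, RunShieldResponse  # assumed re-exports
from llama_stack.apis.safety.routes import router  # assumed module path
class StubSafetyService:
    """Minimal stand-in implementing just the calls the routes make."""
    async def run_shield(self, shield_id, messages, params):
        return RunShieldResponse(violation=None)
    async def run_moderation(self, input, model=None):
        return ModerationObject(id="modr-1", model=model or "stub", results=[])
app = FastAPI()
app.include_router(router)
app.state.impls = {Api.safety: StubSafetyService()}  # what get_safety_service looks up
client = TestClient(app)
resp = client.post(
    "/v1/safety/run-shield",
    json={"shield_id": "content-guard", "messages": [], "params": {}},
)
assert resp.status_code == 200 and resp.json()["violation"] is None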

@@ -1,134 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
:param flagged: Whether any of the below categories are flagged.
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
"""
flagged: bool
categories: dict[str, bool] | None = None
category_applied_input_types: dict[str, list[str]] | None = None
category_scores: dict[str, float] | None = None
user_message: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class ModerationObject(BaseModel):
"""A moderation object.
:param id: The unique identifier for the moderation request.
:param model: The model used to generate the moderation results.
:param results: A list of moderation objects
"""
id: str
model: str
results: list[ModerationObjectResults]
@json_schema_type
class ViolationLevel(Enum):
"""Severity level of a safety violation.
:cvar INFO: Informational level violation that does not require action
:cvar WARN: Warning level violation that suggests caution but allows continuation
:cvar ERROR: Error level violation that requires blocking or intervention
"""
INFO = "info"
WARN = "warn"
ERROR = "error"
@json_schema_type
class SafetyViolation(BaseModel):
"""Details of a safety violation detected by content moderation.
:param violation_level: Severity level of the violation
:param user_message: (Optional) Message to convey to the user about the violation
:param metadata: Additional metadata including specific violation codes for debugging and telemetry
"""
violation_level: ViolationLevel
# what message should you convey to the user
user_message: str | None = None
# additional metadata (including specific violation codes) more for
# debugging, telemetry
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RunShieldResponse(BaseModel):
"""Response from running a safety shield.
:param violation: (Optional) Safety violation detected by the shield, if any
"""
violation: SafetyViolation | None = None
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
@trace_protocol
class Safety(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run shield.
Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: (Optional) The content moderation model you would like to use.
:returns: A moderation object.
"""
...

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from .models import ModerationObject, RunShieldResponse
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
@trace_protocol
class SafetyService(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run shield."""
...
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation."""
...
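Because `SafetyService` is a `runtime_checkable` Protocol, any object exposing these methods and a `shield_store` attribute satisfies an `isinstance` check. A hypothetical in-memory stand-in, assuming the class is importable from the package `__init__`:
# Sketch only: a made-up implementation to illustrate the Protocol; real providers live elsewhere.
from typing import Any
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, SafetyService  # assumed re-exports
class InMemorySafety:
    shield_store = None  # a real implementation would provide a ShieldStore
    async def run_shield(
        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any]
    ) -> RunShieldResponse:
        return RunShieldResponse(violation=None)
    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
        return ModerationObject(id="modr-1", model=model or "noop", results=[])
# runtime_checkable only verifies that the members exist, not their signatures
assert isinstance(InMemorySafety(), SafetyService)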

@@ -4,4 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .models import (
ScoreBatchRequest,
ScoreBatchResponse,
ScoreRequest,
ScoreResponse,
ScoringResult,
ScoringResultRow,
)
from .scoring_service import ScoringFunctionStore, ScoringService
# Backward compatibility - export Scoring as alias for ScoringService
Scoring = ScoringService
__all__ = [
"Scoring",
"ScoringService",
"ScoringFunctionStore",
"ScoreBatchRequest",
"ScoreBatchResponse",
"ScoreRequest",
"ScoreResponse",
"ScoringResult",
"ScoringResultRow",
]

@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type
# mapping of metric to value
ScoringResultRow = dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""A scoring result for a single row."""
score_rows: list[ScoringResultRow] = Field(
..., description="The scoring result for each row. Each row is a map of column name to value"
)
aggregated_results: dict[str, Any] = Field(..., description="Map of metric name to aggregated value")
@json_schema_type
class ScoreBatchResponse(BaseModel):
"""Response from batch scoring operations on datasets."""
dataset_id: str | None = Field(default=None, description="The identifier of the dataset that was scored")
results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult")
@json_schema_type
class ScoreResponse(BaseModel):
"""The response from scoring."""
results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult")
@json_schema_type
class ScoreBatchRequest(BaseModel):
"""Request for batch scoring operations."""
dataset_id: str = Field(..., description="The ID of the dataset to score")
scoring_functions: dict[str, ScoringFnParams | None] = Field(
..., description="The scoring functions to use for the scoring"
)
save_results_dataset: bool = Field(default=False, description="Whether to save the results to a dataset")
@json_schema_type
class ScoreRequest(BaseModel):
"""Request for scoring a list of rows."""
input_rows: list[dict[str, Any]] = Field(..., description="The rows to score")
scoring_functions: dict[str, ScoringFnParams | None] = Field(
..., description="The scoring functions to use for the scoring"
)

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from fastapi import Body, Depends, Request
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
from .scoring_service import ScoringService
def get_scoring_service(request: Request) -> ScoringService:
"""Dependency to get the scoring service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.scoring not in impls:
raise ValueError("Scoring API implementation not found")
return impls[Api.scoring]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Scoring"],
responses=standard_responses,
)
@router.post(
"/scoring/score-batch",
response_model=ScoreBatchResponse,
summary="Score a batch of rows",
description="Score a batch of rows from a dataset",
)
async def score_batch(
body: ScoreBatchRequest = Body(...),
svc: ScoringService = Depends(get_scoring_service),
) -> ScoreBatchResponse:
"""Score a batch of rows from a dataset."""
return await svc.score_batch(
dataset_id=body.dataset_id,
scoring_functions=body.scoring_functions,
save_results_dataset=body.save_results_dataset,
)
@router.post(
"/scoring/score",
response_model=ScoreResponse,
summary="Score a list of rows",
description="Score a list of rows",
)
async def score(
body: ScoreRequest = Body(...),
svc: ScoringService = Depends(get_scoring_service),
) -> ScoreResponse:
"""Score a list of rows."""
return await svc.score(
input_rows=body.input_rows,
scoring_functions=body.scoring_functions,
)
# For backward compatibility with the router registry system
def create_scoring_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Scoring API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.scoring, create_scoring_router)
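The `/scoring/score` handler expects a `ScoreRequest` body. A sketch of a payload a client might send; the row keys and the `basic::equality` function id are placeholders, and the `/v1` prefix assumes `LLAMA_STACK_API_V1` renders as "v1".
# Sketch only: "basic::equality" is a placeholder scoring-function id.
from llama_stack.apis.scoring import ScoreRequest
payload = ScoreRequest(
    input_rows=[
        {"input_query": "2 + 2", "generated_answer": "4", "expected_answer": "4"},
    ],
    scoring_functions={"basic::equality": None},  # None -> use the function's default params
)
# POST this as the JSON body of /v1/scoring/score
print(payload.model_dump_json(indent=2))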

@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: list[ScoringResultRow]
# aggregated metrics to value
aggregated_results: dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
"""Response from batch scoring operations on datasets.
:param dataset_id: (Optional) The identifier of the dataset that was scored
:param results: A map of scoring function name to ScoringResult
"""
dataset_id: str | None = None
results: dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:returns: A ScoreResponse object containing rows and aggregated results.
"""
...

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from .models import ScoreBatchResponse, ScoreResponse
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable
class ScoringService(Protocol):
scoring_function_store: ScoringFunctionStore
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
"""Score a batch of rows."""
...
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse:
"""Score a list of rows."""
...

@@ -4,4 +4,38 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import *
# Import routes to trigger router registration
from . import routes # noqa: F401
from .models import (
AggregationFunctionType,
BasicScoringFnParams,
CommonScoringFnFields,
ListScoringFunctionsResponse,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
RegisterScoringFunctionRequest,
ScoringFn,
ScoringFnInput,
ScoringFnParams,
ScoringFnParamsType,
)
from .scoring_functions_service import ScoringFunctionsService
# Backward compatibility - export ScoringFunctions as alias for ScoringFunctionsService
ScoringFunctions = ScoringFunctionsService
__all__ = [
"ScoringFunctions",
"ScoringFunctionsService",
"ScoringFn",
"ScoringFnInput",
"CommonScoringFnFields",
"ScoringFnParams",
"ScoringFnParamsType",
"LLMAsJudgeScoringFnParams",
"RegexParserScoringFnParams",
"BasicScoringFnParams",
"AggregationFunctionType",
"ListScoringFunctionsResponse",
"RegisterScoringFunctionRequest",
]

@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class ScoringFnParamsType(StrEnum):
"""Types of scoring function parameter configurations."""
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(StrEnum):
"""Types of aggregation functions for scoring results."""
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""Parameters for LLM-as-judge scoring function configuration."""
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str
prompt_template: str | None = None
judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""Parameters for regex parser scoring function configuration."""
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""Parameters for basic scoring function configuration."""
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: str | None = None
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: ScoringFnParams | None = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
"""A scoring function resource for evaluating model outputs."""
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: str | None = None
provider_scoring_fn_id: str | None = None
class ListScoringFunctionsResponse(BaseModel):
"""Response model for listing scoring functions."""
data: list[ScoringFn] = Field(..., description="List of scoring function resources")
@json_schema_type
class RegisterScoringFunctionRequest(BaseModel):
"""Request model for registering a scoring function."""
scoring_fn_id: str = Field(..., description="The ID of the scoring function to register")
description: str = Field(..., description="The description of the scoring function")
return_type: ParamType = Field(..., description="The return type of the scoring function")
provider_scoring_fn_id: str | None = Field(
default=None, description="The ID of the provider scoring function to use for the scoring function"
)
provider_id: str | None = Field(default=None, description="The ID of the provider to use for the scoring function")
params: ScoringFnParams | None = Field(
default=None,
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
)
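`ScoringFnParams` is a discriminated union keyed on `type`, so plain dict data resolves to the correct variant during validation. A small sketch, assuming pydantic v2 (`TypeAdapter`) and a placeholder judge model id:
# Sketch only: the judge model id is a placeholder, not one this diff defines.
from pydantic import TypeAdapter
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFnParams
params = TypeAdapter(ScoringFnParams).validate_python(
    {
        "type": "llm_as_judge",                              # discriminator selects the variant
        "judge_model": "meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
        "judge_score_regexes": [r"Score:\s*(\d+)"],
        "aggregation_functions": ["average"],
    }
)
assert isinstance(params, LLMAsJudgeScoringFnParams)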

@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .models import (
ListScoringFunctionsResponse,
RegisterScoringFunctionRequest,
ScoringFn,
)
from .scoring_functions_service import ScoringFunctionsService
def get_scoring_functions_service(request: Request) -> ScoringFunctionsService:
"""Dependency to get the scoring functions service implementation from app state."""
impls = getattr(request.app.state, "impls", {})
if Api.scoring_functions not in impls:
raise ValueError("Scoring Functions API implementation not found")
return impls[Api.scoring_functions]
router = APIRouter(
prefix=f"/{LLAMA_STACK_API_V1}",
tags=["Scoring Functions"],
responses=standard_responses,
)
@router.get(
"/scoring-functions",
response_model=ListScoringFunctionsResponse,
summary="List all scoring functions",
description="List all scoring functions",
)
async def list_scoring_functions(
svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> ListScoringFunctionsResponse:
"""List all scoring functions."""
return await svc.list_scoring_functions()
@router.get(
"/scoring-functions/{scoring_fn_id:path}",
response_model=ScoringFn,
summary="Get a scoring function by its ID",
description="Get a scoring function by its ID",
)
async def get_scoring_function(
scoring_fn_id: Annotated[str, FastAPIPath(..., description="The ID of the scoring function to get")],
svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> ScoringFn:
"""Get a scoring function by its ID."""
return await svc.get_scoring_function(scoring_fn_id)
@router.post(
"/scoring-functions",
response_model=None,
status_code=204,
summary="Register a scoring function",
description="Register a scoring function",
)
async def register_scoring_function(
body: RegisterScoringFunctionRequest = Body(...),
svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> None:
"""Register a scoring function."""
return await svc.register_scoring_function(
scoring_fn_id=body.scoring_fn_id,
description=body.description,
return_type=body.return_type,
provider_scoring_fn_id=body.provider_scoring_fn_id,
provider_id=body.provider_id,
params=body.params,
)
@router.delete(
"/scoring-functions/{scoring_fn_id:path}",
response_model=None,
status_code=204,
summary="Unregister a scoring function",
description="Unregister a scoring function",
)
async def unregister_scoring_function(
scoring_fn_id: Annotated[str, FastAPIPath(..., description="The ID of the scoring function to unregister")],
svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> None:
"""Unregister a scoring function."""
await svc.unregister_scoring_function(scoring_fn_id)
# For backward compatibility with the router registry system
def create_scoring_functions_router(impl_getter) -> APIRouter:
"""Create a FastAPI router for the Scoring Functions API (legacy compatibility)."""
return router
# Register the router factory
register_router(Api.scoring_functions, create_scoring_functions_router)
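The register endpoint wraps the service call in a `RegisterScoringFunctionRequest` body. A sketch of building that payload; the function id, provider id, and the `{"type": "number"}` shape for `return_type` are assumptions about `ParamType`, not values confirmed by this diff.
# Sketch only: return_type assumes ParamType accepts a "number" variant; ids are placeholders.
from llama_stack.apis.scoring_functions import RegisterScoringFunctionRequest
body = RegisterScoringFunctionRequest(
    scoring_fn_id="basic::subset_of",
    description="Checks whether the expected answer is contained in the generated answer.",
    return_type={"type": "number"},          # assumed ParamType variant
    params={"type": "basic", "aggregation_functions": ["accuracy"]},
    provider_id="basic",
)
# POST this as the JSON body of /v1/scoring-functions (the route returns 204 on success)
print(body.model_dump_json(exclude_none=True, indent=2))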

@@ -1,208 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
from enum import StrEnum
from typing import (
Annotated,
Any,
Literal,
Protocol,
runtime_checkable,
)
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(StrEnum):
"""Types of scoring function parameter configurations.
:cvar llm_as_judge: Use an LLM model to evaluate and score responses
:cvar regex_parser: Use regex patterns to extract and score specific parts of responses
:cvar basic: Basic scoring with simple aggregation functions
"""
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(StrEnum):
"""Types of aggregation functions for scoring results.
:cvar average: Calculate the arithmetic mean of scores
:cvar weighted_average: Calculate a weighted average of scores
:cvar median: Calculate the median value of scores
:cvar categorical_count: Count occurrences of categorical values
:cvar accuracy: Calculate accuracy as the proportion of correct answers
"""
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""Parameters for LLM-as-judge scoring function configuration.
:param type: The type of scoring function parameters, always llm_as_judge
:param judge_model: Identifier of the LLM model to use as a judge for scoring
:param prompt_template: (Optional) Custom prompt template for the judge model
:param judge_score_regexes: Regexes to extract the answer from generated response
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str
prompt_template: str | None = None
judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""Parameters for regex parser scoring function configuration.
:param type: The type of scoring function parameters, always regex_parser
:param parsing_regexes: Regex to extract the answer from generated response
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""Parameters for basic scoring function configuration.
:param type: The type of scoring function parameters, always basic
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: str | None = None
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: ScoringFnParams | None = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
"""A scoring function resource for evaluating model outputs.
:param type: The resource type, always scoring_function
"""
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: str | None = None
provider_scoring_fn_id: str | None = None
class ListScoringFunctionsResponse(BaseModel):
data: list[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.
:param scoring_fn_id: The ID of the scoring function to unregister.
"""
...

Some files were not shown because too many files have changed in this diff.