chore(package): migrate to src/ layout (#3920)

Migrates package structure to src/ layout following Python packaging
best practices.

All code moved from `llama_stack/` to `src/llama_stack/`. Public API
unchanged - imports remain `import llama_stack.*`.

Updated build configs, pre-commit hooks, scripts, and GitHub workflows
accordingly. All hooks pass, package builds cleanly.

**Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
Ashwin Bharambe 2025-10-27 12:02:21 -07:00 committed by GitHub
parent 98a5047f9d
commit 471b1b248b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
791 changed files with 2983 additions and 456 deletions

View file

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.library_client import ( # noqa: F401
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import *

View file

@ -0,0 +1,894 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseText,
)
@json_schema_type
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation.
:param type: The type/identifier of the guardrail.
"""
type: str
# TODO: more fields to be added for guardrail configuration
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
Step = Annotated[
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System.
:param turn_id: Unique identifier for the turn within a session
:param session_id: Unique identifier for the conversation session
:param input_messages: List of messages that initiated this turn
:param steps: Ordered list of processing steps executed during this turn
:param output_message: The model's generated response containing content and metadata
:param output_attachments: (Optional) Files or media attached to the agent's response
:param started_at: Timestamp when the turn began
:param completed_at: (Optional) Timestamp when the turn finished, if completed
"""
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System.
:param session_id: Unique identifier for the conversation session
:param session_name: Human-readable name for the session
:param turns: List of all turns that have occurred in this session
:param started_at: Timestamp when the session was created
"""
session_id: str
session_name: str
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: dict[str, Any]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent.
:param model: The model identifier to use for the agent
:param instructions: The system instructions for the agent
:param name: Optional name for the agent, used in telemetry and identification
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
:param response_format: Optional response format configuration
"""
model: str
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
class Agent(BaseModel):
"""An agent instance with configuration and metadata.
:param agent_id: Unique identifier for the agent
:param agent_config: Configuration settings for the agent
:param created_at: Timestamp when the agent was created
"""
agent_id: str
agent_config: AgentConfig
created_at: datetime
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
"""Payload for step start events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param metadata: (Optional) Additional metadata for the step
"""
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
"""Payload for step completion events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param step_details: Complete details of the executed step
"""
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
"""Payload for step progress events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param delta: Incremental content changes during step execution
"""
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
delta: ContentDelta
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
"""Payload for turn start events in agent turn responses.
:param event_type: Type of event being reported
:param turn_id: Unique identifier for the turn within a session
"""
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
"""Payload for turn completion events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Complete turn data including all steps and results
"""
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
"""Payload for turn awaiting input events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Turn data when waiting for external tool responses
"""
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
AgentTurnResponseEventPayload = Annotated[
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
"""An event in an agent turn response stream.
:param payload: Event-specific payload containing event data
"""
payload: AgentTurnResponseEventPayload
@json_schema_type
class AgentCreateResponse(BaseModel):
"""Response returned when creating a new agent.
:param agent_id: Unique identifier for the created agent
"""
agent_id: str
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
"""Response returned when creating a new agent session.
:param session_id: Unique identifier for the created session
"""
session_id: str
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
"""Request to create a new turn for an agent.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param messages: List of messages to start the turn with
:param documents: (Optional) List of documents to provide to the agent
:param toolgroups: (Optional) List of tool groups to make available for this turn
:param stream: (Optional) Whether to stream the response
:param tool_config: (Optional) Tool configuration to override agent defaults
"""
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
"""Request to resume an agent turn with tool responses.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param turn_id: Unique identifier for the turn within a session
:param tool_responses: List of tool responses to submit to continue the turn
:param stream: (Optional) Whether to stream the response
"""
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""Streamed agent turn completion response.
:param event: Individual event in the agent turn response stream
"""
event: AgentTurnResponseEvent
@json_schema_type
class AgentStepResponse(BaseModel):
"""Response containing details of a specific agent step.
:param step: The complete step data and execution details
"""
step: Step
@runtime_checkable
class Agents(Protocol):
"""Agents
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
level=LLAMA_STACK_API_V1ALPHA,
)
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(
route="/agents/{agent_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@webmethod(
route="/agents/{agent_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent of the agent.
"""
...
@webmethod(
route="/agents/{agent_id}/sessions",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
# We situate the OpenAI Responses API in the Agents API just like we did things
# for Inference. The Responses API, in its intent, serves the same purpose as
# the Agents API above -- it is essentially a lightweight "agentic loop" with
# integrated tool calling.
#
# Both of these APIs are inherently stateful.
@webmethod(
route="/openai/v1/responses/{response_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
"""Get a model response.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
:param model: The model to filter responses by.
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
:returns: A ListOpenAIResponseObject.
"""
...
@webmethod(
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: The order to return the input items in. Default is desc.
:returns: An ListOpenAIResponseInputItem.
"""
...
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete a response.
:param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject
"""
...

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batches import Batches, BatchObject, ListBatchesResponse
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
object: Literal["list"] = "list"
data: list[BatchObject] = Field(..., description="List of batch objects")
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
has_more: bool = Field(default=False, description="Whether there are more batches available")
@runtime_checkable
class Batches(Protocol):
"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
:param input_file_id: The ID of an uploaded file containing requests for the batch.
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
:param batch_id: The ID of the batch to retrieve.
:returns: The batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
:param batch_id: The ID of the batch to cancel.
:returns: The updated batch object.
"""
...
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""List all batches for the current user.
:param after: A cursor for pagination; returns batches after this batch ID.
:param limit: Number of batches to return (default 20, max 100).
:returns: A list of batch objects.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .benchmarks import *

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
dataset_id: str
scoring_functions: list[str]
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
"""A benchmark resource for evaluating model performance.
:param dataset_id: Identifier of the dataset to use for the benchmark evaluation
:param scoring_functions: List of scoring function identifiers to apply during evaluation
:param metadata: Metadata for this evaluation task
:param type: The resource type, always benchmark
"""
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property
def benchmark_id(self) -> str:
return self.identifier
@property
def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
benchmark_id: str
provider_id: str | None = None
provider_benchmark_id: str | None = None
class ListBenchmarksResponse(BaseModel):
data: list[Benchmark]
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.
:param benchmark_id: The ID of the benchmark to unregister.
"""
...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Literal
from pydantic import BaseModel, Field, model_validator
from llama_stack.models.llama.datatypes import ToolCall
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class URL(BaseModel):
"""A URL reference to external content.
:param uri: The URL string pointing to the resource
"""
uri: str
class _URLOrData(BaseModel):
"""
A URL or a base64 encoded string
:param url: A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits.
:param data: base64 encoded image data as string
"""
url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
@model_validator(mode="before")
@classmethod
def validator(cls, values):
if isinstance(values, dict):
return values
return {"url": values}
@json_schema_type
class ImageContentItem(BaseModel):
"""A image content item
:param type: Discriminator type of the content item. Always "image"
:param image: Image as a base64 encoded string or an URL
"""
type: Literal["image"] = "image"
image: _URLOrData
@json_schema_type
class TextContentItem(BaseModel):
"""A text content item
:param type: Discriminator type of the content item. Always "text"
:param text: Text content
"""
type: Literal["text"] = "text"
text: str
# other modalities can be added here
InterleavedContentItem = Annotated[
ImageContentItem | TextContentItem,
Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
class TextDelta(BaseModel):
"""A text content delta for streaming responses.
:param type: Discriminator type of the delta. Always "text"
:param text: The incremental text content
"""
type: Literal["text"] = "text"
text: str
@json_schema_type
class ImageDelta(BaseModel):
"""An image content delta for streaming responses.
:param type: Discriminator type of the delta. Always "image"
:param image: The incremental image data as bytes
"""
type: Literal["image"] = "image"
image: bytes
class ToolCallParseStatus(Enum):
"""Status of tool call parsing during streaming.
:cvar started: Tool call parsing has begun
:cvar in_progress: Tool call parsing is ongoing
:cvar failed: Tool call parsing failed
:cvar succeeded: Tool call parsing completed successfully
"""
started = "started"
in_progress = "in_progress"
failed = "failed"
succeeded = "succeeded"
@json_schema_type
class ToolCallDelta(BaseModel):
"""A tool call content delta for streaming responses.
:param type: Discriminator type of the delta. Always "tool_call"
:param tool_call: Either an in-progress tool call string or the final parsed tool call
:param parse_status: Current parsing status of the tool call
"""
type: Literal["tool_call"] = "tool_call"
# you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the
# stream
tool_call: str | ToolCall
parse_status: ToolCallParseStatus
# streaming completions send a stream of ContentDeltas
ContentDelta = Annotated[
TextDelta | ImageDelta | ToolCallDelta,
Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Custom Llama Stack Exception classes should follow the following schema
# 1. All classes should inherit from an existing Built-In Exception class: https://docs.python.org/3/library/exceptions.html
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
class ResourceNotFoundError(ValueError):
"""generic exception for a missing Llama Stack resource"""
def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None:
message = (
f"{resource_type} '{resource_name}' not found. Use '{client_list}' to list available {resource_type}s."
)
super().__init__(message)
class UnsupportedModelError(ValueError):
"""raised when model is not present in the list of supported models"""
def __init__(self, model_name: str, supported_models_list: list[str]):
message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
super().__init__(message)
class ModelNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced model"""
def __init__(self, model_name: str) -> None:
super().__init__(model_name, "Model", "client.models.list()")
class VectorStoreNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced vector store"""
def __init__(self, vector_store_name: str) -> None:
super().__init__(vector_store_name, "Vector Store", "client.vector_dbs.list()")
class DatasetNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced dataset"""
def __init__(self, dataset_name: str) -> None:
super().__init__(dataset_name, "Dataset", "client.datasets.list()")
class ToolGroupNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced tool group"""
def __init__(self, toolgroup_name: str) -> None:
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
class SessionNotFoundError(ValueError):
"""raised when Llama Stack cannot find a referenced session or access is denied"""
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""
def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
message = (
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
)
super().__init__(message)
class ConflictError(ValueError):
"""raised when an operation cannot be performed due to a conflict with the current state"""
def __init__(self, message: str) -> None:
super().__init__(message)
class TokenValidationError(ValueError):
"""raised when token validation fails during authentication"""
def __init__(self, message: str) -> None:
super().__init__(message)
class ConversationNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced conversation"""
def __init__(self, conversation_id: str) -> None:
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
class InvalidConversationIdError(ValueError):
"""raised when a conversation ID has an invalid format"""
def __init__(self, conversation_id: str) -> None:
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
super().__init__(message)

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
class JobStatus(Enum):
"""Status of a job execution.
:cvar completed: Job has finished successfully
:cvar in_progress: Job is currently running
:cvar failed: Job has failed during execution
:cvar scheduled: Job is scheduled but not yet started
:cvar cancelled: Job was cancelled before completion
"""
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
cancelled = "cancelled"
@json_schema_type
class Job(BaseModel):
"""A job execution instance with status tracking.
:param job_id: Unique identifier for the job
:param status: Current execution status of the job
"""
job_id: str
status: JobStatus

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
class Order(Enum):
"""Sort order for paginated responses.
:cvar asc: Ascending order
:cvar desc: Descending order
"""
asc = "asc"
desc = "desc"
@json_schema_type
class PaginatedResponse(BaseModel):
"""A generic paginated response that follows a simple format.
:param data: The list of items for the current page
:param has_more: Whether there are more items available after this set
:param url: The URL for accessing this list
"""
data: list[dict[str, Any]]
has_more: bool
url: str | None = None

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PostTrainingMetric(BaseModel):
"""Training metrics captured during post-training jobs.
:param epoch: Training epoch number
:param train_loss: Loss value on the training dataset
:param validation_loss: Loss value on the validation dataset
:param perplexity: Perplexity metric indicating model confidence
"""
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type
class Checkpoint(BaseModel):
"""Checkpoint created during training runs.
:param identifier: Unique identifier for the checkpoint
:param created_at: Timestamp when the checkpoint was created
:param epoch: Training epoch when the checkpoint was saved
:param post_training_job_id: Identifier of the training job that created this checkpoint
:param path: File system path where the checkpoint is stored
:param training_metrics: (Optional) Training metrics associated with this checkpoint
"""
identifier: str
created_at: datetime
epoch: int
post_training_job_id: str
path: str
training_metrics: PostTrainingMetric | None = None

View file

@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class StringType(BaseModel):
"""Parameter type for string values.
:param type: Discriminator type. Always "string"
"""
type: Literal["string"] = "string"
@json_schema_type
class NumberType(BaseModel):
"""Parameter type for numeric values.
:param type: Discriminator type. Always "number"
"""
type: Literal["number"] = "number"
@json_schema_type
class BooleanType(BaseModel):
"""Parameter type for boolean values.
:param type: Discriminator type. Always "boolean"
"""
type: Literal["boolean"] = "boolean"
@json_schema_type
class ArrayType(BaseModel):
"""Parameter type for array values.
:param type: Discriminator type. Always "array"
"""
type: Literal["array"] = "array"
@json_schema_type
class ObjectType(BaseModel):
"""Parameter type for object values.
:param type: Discriminator type. Always "object"
"""
type: Literal["object"] = "object"
@json_schema_type
class JsonType(BaseModel):
"""Parameter type for JSON values.
:param type: Discriminator type. Always "json"
"""
type: Literal["json"] = "json"
@json_schema_type
class UnionType(BaseModel):
"""Parameter type for union values.
:param type: Discriminator type. Always "union"
"""
type: Literal["union"] = "union"
@json_schema_type
class ChatCompletionInputType(BaseModel):
"""Parameter type for chat completion input.
:param type: Discriminator type. Always "chat_completion_input"
"""
# expects List[Message] for messages
type: Literal["chat_completion_input"] = "chat_completion_input"
@json_schema_type
class CompletionInputType(BaseModel):
"""Parameter type for completion input.
:param type: Discriminator type. Always "completion_input"
"""
# expects InterleavedTextMedia for content
type: Literal["completion_input"] = "completion_input"
@json_schema_type
class AgentTurnInputType(BaseModel):
"""Parameter type for agent turn input.
:param type: Discriminator type. Always "agent_turn_input"
"""
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
@json_schema_type
class DialogType(BaseModel):
"""Parameter type for dialog data with semantic output labels.
:param type: Discriminator type. Always "dialog"
"""
# expects List[Message] for messages
# this type semantically contains the output label whereas ChatCompletionInputType does not
type: Literal["dialog"] = "dialog"
ParamType = Annotated[
StringType
| NumberType
| BooleanType
| ArrayType
| ObjectType
| JsonType
| UnionType
| ChatCompletionInputType
| CompletionInputType
| AgentTurnInputType,
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")
"""
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
class CustomType(BaseModel):
pylint: disable=syntax-error
type: Literal["custom"] = "custom"
validator_class: str
"""

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -0,0 +1,298 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Annotated, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
"""OpenAI-compatible conversation object."""
id: str = Field(..., description="The unique ID of the conversation.")
object: Literal["conversation"] = Field(
default="conversation", description="The object type, which is always conversation."
)
created_at: int = Field(
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
)
metadata: Metadata | None = Field(
default=None,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
)
items: list[dict] | None = Field(
default=None,
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
)
@json_schema_type
class ConversationMessage(BaseModel):
"""OpenAI-compatible message item for conversations."""
id: str = Field(..., description="unique identifier for this message")
content: list[dict] = Field(..., description="message content")
role: str = Field(..., description="message role")
status: str = Field(..., description="message status")
type: Literal["message"] = "message"
object: Literal["message"] = "message"
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
| OpenAIResponseMCPApprovalRequest
| OpenAIResponseMCPApprovalResponse
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
id: str = Field(..., description="The deleted conversation identifier")
object: str = Field(default="conversation.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
"""Request body for creating conversation items."""
items: list[ConversationItem] = Field(
...,
description="Items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
class ConversationItemInclude(StrEnum):
"""
Specify additional output data to include in the model response.
"""
web_search_call_action_sources = "web_search_call.action.sources"
code_interpreter_call_outputs = "code_interpreter_call.outputs"
computer_call_output_output_image_url = "computer_call_output.output.image_url"
file_search_call_results = "file_search_call.results"
message_input_image_image_url = "message.input_image.image_url"
message_output_text_logprobs = "message.output_text.logprobs"
reasoning_encrypted_content = "reasoning.encrypted_content"
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
object: str = Field(default="list", description="Object type")
data: list[ConversationItem] = Field(..., description="List of conversation items")
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
has_more: bool = Field(default=False, description="Whether there are more items available")
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
"""Response for deleted conversation item."""
id: str = Field(..., description="The deleted item identifier")
object: str = Field(default="conversation.item.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Conversations
Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Retrieve a conversation.
Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation.
Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation.
Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items.
Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve an item.
Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items.
List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete an item.
Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasetio import *

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import webmethod
class DatasetStore(Protocol):
def get_dataset(self, dataset_id: str) -> Dataset: ...
@runtime_checkable
class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""Get a paginated list of rows from a dataset.
Uses offset-based pagination where:
- start_index: The starting index (0-based). If None, starts from beginning.
- limit: Number of items to return. If None or -1, returns all items.
The response includes:
- data: List of items for the current page.
- has_more: Whether there are more items available after this set.
:param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
:returns: A PaginatedResponse.
"""
...
@webmethod(
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.
:param dataset_id: The ID of the dataset to append the rows to.
:param rows: The rows to append to the dataset.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import *

View file

@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(StrEnum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: list[dict[str, Any]]
DataSource = Annotated[
URIDataSource | RowsDataSource,
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
"""
Common fields for a dataset.
:param purpose: Purpose of the dataset indicating its intended use
:param source: Data source configuration for the dataset
:param metadata: Additional metadata for the dataset
"""
purpose: DatasetPurpose
source: DataSource
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
"""Dataset resource for storing and accessing training or evaluation data.
:param type: Type of resource, always 'dataset' for datasets
"""
type: Literal[ResourceType.dataset] = ResourceType.dataset
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str | None:
return self.provider_resource_id
class DatasetInput(CommonDatasetFields, BaseModel):
"""Input parameters for dataset operations.
:param dataset_id: Unique identifier for the dataset
"""
dataset_id: str
class ListDatasetsResponse(BaseModel):
"""Response from listing datasets.
:param data: List of datasets
"""
data: list[Dataset]
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def get_dataset(
self,
dataset_id: str,
) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...

View file

@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, EnumMeta
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum, metaclass=DynamicApiMeta):
"""Enumeration of all available APIs in the Llama Stack system.
:cvar providers: Provider management and configuration
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
:cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
:cvar eval: Model evaluation and benchmarking framework
:cvar post_training: Fine-tuning and model training
:cvar tool_runtime: Tool execution and management
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
:cvar tool_groups: Tool group organization
:cvar files: File storage and management
:cvar prompts: Prompt versions and management
:cvar inspect: Built-in system inspection and introspection
"""
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"
tool_runtime = "tool_runtime"
models = "models"
shields = "shields"
vector_stores = "vector_stores" # only used for routing table
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval import *

View file

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: SystemMessage | None = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: int | None = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optinally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: list[dict[str, Any]]
# each key in the dict is a scoring function name
scores: dict[str, ScoringResult]
class Eval(Protocol):
"""Evaluations
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:returns: The job that was created to run the evaluation.
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:returns: EvaluateResponse object containing generations and scores.
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:returns: The status of the evaluation job.
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
)
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:returns: The result of the job.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .files import *

View file

@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Files API Models
class OpenAIFilePurpose(StrEnum):
"""
Valid purpose values for OpenAI Files API.
"""
ASSISTANTS = "assistants"
BATCH = "batch"
# TODO: Add other purposes as needed
@json_schema_type
class OpenAIFileObject(BaseModel):
"""
OpenAI File object as defined in the OpenAI Files API.
:param object: The object type, which is always "file"
:param id: The file identifier, which can be referenced in the API endpoints
:param bytes: The size of the file, in bytes
:param created_at: The Unix timestamp (in seconds) for when the file was created
:param expires_at: The Unix timestamp (in seconds) for when the file expires
:param filename: The name of the file
:param purpose: The intended purpose of the file
"""
object: Literal["file"] = "file"
id: str
bytes: int
created_at: int
expires_at: int
filename: str
purpose: OpenAIFilePurpose
@json_schema_type
class ExpiresAfter(BaseModel):
"""
Control expiration of uploaded files.
Params:
- anchor, must be "created_at"
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
"""
MIN: ClassVar[int] = 3600 # 1 hour
MAX: ClassVar[int] = 2592000 # 30 days
anchor: Literal["created_at"]
seconds: int = Field(..., ge=3600, le=2592000)
@json_schema_type
class ListOpenAIFileResponse(BaseModel):
"""
Response for listing files in OpenAI Files API.
:param data: List of file objects
:param has_more: Whether there are more files available beyond this page
:param first_id: ID of the first file in the list for pagination
:param last_id: ID of the last file in the list for pagination
:param object: The object type, which is always "list"
"""
data: list[OpenAIFileObject]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
@json_schema_type
class OpenAIFileDeleteResponse(BaseModel):
"""
Response for deleting a file in OpenAI Files API.
:param id: The file identifier that was deleted
:param object: The object type, which is always "file"
:param deleted: Whether the file was successfully deleted
"""
id: str
object: Literal["file"] = "file"
deleted: bool
@runtime_checkable
@trace_protocol
class Files(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""Upload file.
Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded.
- purpose: The intended purpose of the uploaded file.
- expires_after: Optional form values describing expiration for the file.
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
:param expires_after: Optional form values describing expiration for the file.
:returns: An OpenAIFileObject representing the uploaded file.
"""
...
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""List files.
Returns a list of files that belong to the user's organization.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param purpose: Only return files with the given purpose.
:returns: An ListOpenAIFileResponse containing the list of files.
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
file_id: str,
) -> OpenAIFileObject:
"""Retrieve file.
Returns information about a specific file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileObject containing file information.
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
file_id: str,
) -> OpenAIFileDeleteResponse:
"""Delete file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
"""
...
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,
file_id: str,
) -> Response:
"""Retrieve file content.
Returns the contents of the specified file.
:param file_id: The ID of the file to use for this request.
:returns: The raw file content as a binary response.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import *

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import *

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class RouteInfo(BaseModel):
"""Information about an API route including its path, method, and implementing providers.
:param route: The API endpoint path
:param method: HTTP method for the route
:param provider_types: List of provider types that implement this route
"""
route: str
method: str
provider_types: list[str]
@json_schema_type
class HealthInfo(BaseModel):
"""Health status information for the service.
:param status: Current health status of the service
"""
status: HealthStatus
@json_schema_type
class VersionInfo(BaseModel):
"""Version information for the service.
:param version: Version number of the service
"""
version: str
class ListRoutesResponse(BaseModel):
"""Response containing a list of all available API routes.
:param data: List of available route information objects
"""
data: list[RouteInfo]
@runtime_checkable
class Inspect(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
"""List routes.
List all available API routes with their methods and implementing providers.
:returns: Response containing information about all available routes.
"""
...
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def health(self) -> HealthInfo:
"""Get health status.
Get the current health status of the service.
:returns: Health information indicating if the service is operational.
"""
...
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def version(self) -> VersionInfo:
"""Get version.
Get the version of the service.
:returns: Version information containing the service version number.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import *

View file

@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel):
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
@json_schema_type
class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
"""
llm = "llm"
embedding = "embedding"
rerank = "rerank"
@json_schema_type
class Model(CommonModelFields, Resource):
"""A model resource representing an AI model registered in Llama Stack.
:param type: The resource type, always 'model' for model resources
:param model_type: The type of model (LLM or embedding model)
:param metadata: Any additional metadata for this model
:param identifier: Unique identifier for this resource in llama stack
:param provider_resource_id: Unique identifier for this resource in the provider
:param provider_id: ID of the provider that owns this resource
"""
type: Literal[ResourceType.model] = ResourceType.model
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
assert self.provider_resource_id is not None, "Provider resource ID must be set"
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
@field_validator("provider_resource_id")
@classmethod
def validate_provider_resource_id(cls, v):
if v is None:
raise ValueError("provider_resource_id cannot be None")
return v
class ModelInput(CommonModelFields):
model_id: str
provider_id: str | None = None
provider_model_id: str | None = None
model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: list[Model]
@json_schema_type
class OpenAIModel(BaseModel):
"""A model from OpenAI.
:id: The ID of the model
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class OpenAIListModelsResponse(BaseModel):
data: list[OpenAIModel]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: A OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_model(
self,
model_id: str,
) -> Model:
"""Get model.
Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
"""Register model.
Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_model(
self,
model_id: str,
) -> None:
"""Unregister model.
Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import *

View file

@ -0,0 +1,374 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class OptimizerType(Enum):
"""Available optimizer algorithms for training.
:cvar adam: Adaptive Moment Estimation optimizer
:cvar adamw: AdamW optimizer with weight decay
:cvar sgd: Stochastic Gradient Descent optimizer
"""
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DatasetFormat(Enum):
"""Format of the training dataset.
:cvar instruct: Instruction-following format with prompt and completion
:cvar dialog: Multi-turn conversation format with messages
"""
instruct = "instruct"
dialog = "dialog"
@json_schema_type
class DataConfig(BaseModel):
"""Configuration for training data and data loading.
:param dataset_id: Unique identifier for the training dataset
:param batch_size: Number of samples per training batch
:param shuffle: Whether to shuffle the dataset during training
:param data_format: Format of the dataset (instruct or dialog)
:param validation_dataset_id: (Optional) Unique identifier for the validation dataset
:param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
:param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
"""
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: str | None = None
packed: bool | None = False
train_on_input: bool | None = False
@json_schema_type
class OptimizerConfig(BaseModel):
"""Configuration parameters for the optimization algorithm.
:param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
:param lr: Learning rate for the optimizer
:param weight_decay: Weight decay coefficient for regularization
:param num_warmup_steps: Number of steps for learning rate warmup
"""
optimizer_type: OptimizerType
lr: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
"""Configuration for memory and compute efficiency optimizations.
:param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
:param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
:param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
:param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
"""
enable_activation_checkpointing: bool | None = False
enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: bool | None = False
@json_schema_type
class TrainingConfig(BaseModel):
"""Comprehensive configuration for the training process.
:param n_epochs: Number of training epochs to run
:param max_steps_per_epoch: Maximum number of steps to run per epoch
:param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
:param max_validation_steps: (Optional) Maximum number of validation steps per epoch
:param data_config: (Optional) Configuration for data loading and formatting
:param optimizer_config: (Optional) Configuration for the optimization algorithm
:param efficiency_config: (Optional) Configuration for memory and compute optimizations
:param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
"""
n_epochs: int
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: int | None = 1
data_config: DataConfig | None = None
optimizer_config: OptimizerConfig | None = None
efficiency_config: EfficiencyConfig | None = None
dtype: str | None = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning.
:param type: Algorithm type identifier, always "LoRA"
:param lora_attn_modules: List of attention module names to apply LoRA to
:param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
:param apply_lora_to_output: Whether to apply LoRA to output projection layers
:param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
:param alpha: LoRA scaling parameter that controls adaptation strength
:param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
:param quantize_base: (Optional) Whether to quantize the base model weights
"""
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: list[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool | None = False
quantize_base: bool | None = False
@json_schema_type
class QATFinetuningConfig(BaseModel):
"""Configuration for Quantization-Aware Training (QAT) fine-tuning.
:param type: Algorithm type identifier, always "QAT"
:param quantizer_name: Name of the quantization algorithm to use
:param group_size: Size of groups for grouped quantization
"""
type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job.
:param job_uuid: Unique identifier for the training job
:param log_lines: List of log message strings from the training process
"""
job_uuid: str
log_lines: list[str]
@json_schema_type
class RLHFAlgorithm(Enum):
"""Available reinforcement learning from human feedback algorithms.
:cvar dpo: Direct Preference Optimization algorithm
"""
dpo = "dpo"
@json_schema_type
class DPOLossType(Enum):
sigmoid = "sigmoid"
hinge = "hinge"
ipo = "ipo"
kto_pair = "kto_pair"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
"""Configuration for Direct Preference Optimization (DPO) alignment.
:param beta: Temperature parameter for the DPO loss
:param loss_type: The type of loss function to use for DPO
"""
beta: float
loss_type: DPOLossType = DPOLossType.sigmoid
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model using reinforcement learning from human feedback.
:param job_uuid: Unique identifier for the training job
:param finetuned_model: URL or path to the base model to fine-tune
:param dataset_id: Unique identifier for the training dataset
:param validation_dataset_id: Unique identifier for the validation dataset
:param algorithm: RLHF algorithm to use for training
:param algorithm_config: Configuration parameters for the RLHF algorithm
:param optimizer_config: Configuration parameters for the optimization algorithm
:param training_config: Configuration parameters for the training process
:param hyperparam_search_config: Configuration for hyperparameter search
:param logger_config: Configuration for training logging
"""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: dict[str, Any]
logger_config: dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job.
:param job_uuid: Unique identifier for the training job
:param status: Current status of the training job
:param scheduled_at: (Optional) Timestamp when the job was scheduled
:param started_at: (Optional) Timestamp when the job execution began
:param completed_at: (Optional) Timestamp when the job finished, if completed
:param resources_allocated: (Optional) Information about computational resources allocated to the job
:param checkpoints: List of model checkpoints created during training
"""
job_uuid: str
status: JobStatus
scheduled_at: datetime | None = None
started_at: datetime | None = None
completed_at: datetime | None = None
resources_allocated: dict[str, Any] | None = None
checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: list[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job.
:param job_uuid: Unique identifier for the training job
:param checkpoints: List of model checkpoints created during training
"""
job_uuid: str
checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str | None = Field(
default=None,
description="Model descriptor for training if not in provider config`",
),
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

View file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import secrets
from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
:param version: Version (integer starting at 1, incremented on save)
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
:param variables: List of prompt variable names that can be used in the prompt template
:param is_default: Boolean indicating whether this version is the default version for this prompt
"""
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template"
)
is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version"
)
@field_validator("prompt_id")
@classmethod
def validate_prompt_id(cls, prompt_id: str) -> str:
if not isinstance(prompt_id, str):
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
if not prompt_id.startswith("pmpt_"):
raise ValueError("prompt_id must start with 'pmpt_' prefix")
hex_part = prompt_id[5:]
if len(hex_part) != 48:
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
for char in hex_part:
if char not in "0123456789abcdef":
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
return prompt_id
@field_validator("version")
@classmethod
def validate_version(cls, prompt_version: int) -> int:
if prompt_version < 1:
raise ValueError("version must be >= 1")
return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod
def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes)
random_bytes = secrets.token_bytes(24)
hex_string = random_bytes.hex()
return f"pmpt_{hex_string}"
class ListPromptsResponse(BaseModel):
"""Response model to list prompts."""
data: list[Prompt]
@runtime_checkable
@trace_protocol
class Prompts(Protocol):
"""Prompts
Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts.
:returns: A ListPromptsResponse containing all prompts.
"""
...
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompt_versions(
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List prompt versions.
List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_prompt(
self,
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get prompt.
Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest).
:returns: A Prompt resource.
"""
...
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create prompt.
Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template.
:returns: The created Prompt resource.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update prompt.
Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
:param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template.
:param set_as_default: Set the new version as the default (default=True).
:returns: The updated Prompt resource with incremented version.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_prompt(
self,
prompt_id: str,
) -> None:
"""Delete prompt.
Delete a prompt.
:param prompt_id: The identifier of the prompt to delete.
"""
...
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
async def set_default_version(
self,
prompt_id: str,
version: int,
) -> Prompt:
"""Set prompt version.
Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt.
:param version: The version to set as default.
:returns: The prompt with the specified version now set as default.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import *

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
"""Information about a registered provider including its configuration and health status.
:param api: The API name this provider implements
:param provider_id: Unique identifier for the provider
:param provider_type: The type of provider implementation
:param config: Configuration parameters for the provider
:param health: Current health status of the provider
"""
api: str
provider_id: str
provider_type: str
config: dict[str, Any]
health: HealthResponse
class ListProvidersResponse(BaseModel):
"""Response containing a list of all available providers.
:param data: List of provider information objects
"""
data: list[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""Providers
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse:
"""List providers.
List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get provider.
Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pydantic import BaseModel, Field
class ResourceType(StrEnum):
model = "model"
shield = "shield"
vector_store = "vector_store"
dataset = "dataset"
scoring_function = "scoring_function"
benchmark = "benchmark"
tool = "tool"
tool_group = "tool_group"
prompt = "prompt"
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str | None = Field(
default=None,
description="Unique identifier for this resource in the provider",
)
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_store', etc.)")

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import *

View file

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
:param flagged: Whether any of the below categories are flagged.
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
"""
flagged: bool
categories: dict[str, bool] | None = None
category_applied_input_types: dict[str, list[str]] | None = None
category_scores: dict[str, float] | None = None
user_message: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class ModerationObject(BaseModel):
"""A moderation object.
:param id: The unique identifier for the moderation request.
:param model: The model used to generate the moderation results.
:param results: A list of moderation objects
"""
id: str
model: str
results: list[ModerationObjectResults]
@json_schema_type
class ViolationLevel(Enum):
"""Severity level of a safety violation.
:cvar INFO: Informational level violation that does not require action
:cvar WARN: Warning level violation that suggests caution but allows continuation
:cvar ERROR: Error level violation that requires blocking or intervention
"""
INFO = "info"
WARN = "warn"
ERROR = "error"
@json_schema_type
class SafetyViolation(BaseModel):
"""Details of a safety violation detected by content moderation.
:param violation_level: Severity level of the violation
:param user_message: (Optional) Message to convey to the user about the violation
:param metadata: Additional metadata including specific violation codes for debugging and telemetry
"""
violation_level: ViolationLevel
# what message should you convey to the user
user_message: str | None = None
# additional metadata (including specific violation codes) more for
# debugging, telemetry
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RunShieldResponse(BaseModel):
"""Response from running a safety shield.
:param violation: (Optional) Safety violation detected by the shield, if any
"""
violation: SafetyViolation | None = None
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
@trace_protocol
class Safety(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run shield.
Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: (Optional) The content moderation model you would like to use.
:returns: A moderation object.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import *

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: list[ScoringResultRow]
# aggregated metrics to value
aggregated_results: dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
"""Response from batch scoring operations on datasets.
:param dataset_id: (Optional) The identifier of the dataset that was scored
:param results: A map of scoring function name to ScoringResult
"""
dataset_id: str | None = None
results: dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:returns: A ScoreResponse object containing rows and aggregated results.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import *

View file

@ -0,0 +1,208 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
from enum import StrEnum
from typing import (
Annotated,
Any,
Literal,
Protocol,
runtime_checkable,
)
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(StrEnum):
"""Types of scoring function parameter configurations.
:cvar llm_as_judge: Use an LLM model to evaluate and score responses
:cvar regex_parser: Use regex patterns to extract and score specific parts of responses
:cvar basic: Basic scoring with simple aggregation functions
"""
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(StrEnum):
"""Types of aggregation functions for scoring results.
:cvar average: Calculate the arithmetic mean of scores
:cvar weighted_average: Calculate a weighted average of scores
:cvar median: Calculate the median value of scores
:cvar categorical_count: Count occurrences of categorical values
:cvar accuracy: Calculate accuracy as the proportion of correct answers
"""
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""Parameters for LLM-as-judge scoring function configuration.
:param type: The type of scoring function parameters, always llm_as_judge
:param judge_model: Identifier of the LLM model to use as a judge for scoring
:param prompt_template: (Optional) Custom prompt template for the judge model
:param judge_score_regexes: Regexes to extract the answer from generated response
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str
prompt_template: str | None = None
judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""Parameters for regex parser scoring function configuration.
:param type: The type of scoring function parameters, always regex_parser
:param parsing_regexes: Regex to extract the answer from generated response
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response",
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=lambda: [],
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""Parameters for basic scoring function configuration.
:param type: The type of scoring function parameters, always basic
:param aggregation_functions: Aggregation functions to apply to the scores of each row
"""
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: str | None = None
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: ScoringFnParams | None = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
"""A scoring function resource for evaluating model outputs.
:param type: The resource type, always scoring_function
"""
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: str | None = None
provider_scoring_fn_id: str | None = None
class ListScoringFunctionsResponse(BaseModel):
data: list[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.
:param scoring_fn_id: The ID of the scoring function to unregister.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import *

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel):
params: dict[str, Any] | None = None
@json_schema_type
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content.
:param params: (Optional) Configuration parameters for the shield
:param type: The resource type, always shield
"""
type: Literal[ResourceType.shield] = ResourceType.shield
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str | None:
return self.provider_resource_id
class ShieldInput(CommonShieldFields):
shield_id: str
provider_id: str | None = None
provider_shield_id: str | None = None
class ListShieldsResponse(BaseModel):
data: list[Shield]
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
:returns: A ListShieldsResponse.
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
:param identifier: The identifier of the shield to get.
:returns: A Shield.
"""
...
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
"""Register a shield.
:param shield_id: The identifier of the shield to register.
:param provider_shield_id: The identifier of the shield in the provider.
:param provider_id: The identifier of the provider.
:param params: The parameters of the shield.
:returns: A Shield.
"""
...
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.
:param identifier: The identifier of the shield to unregister.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import *

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):
"""The type of filtering function.
:cvar none: No filtering applied, accept all generated synthetic data
:cvar random: Random sampling of generated data points
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
:cvar sigmoid: Apply sigmoid function for probability-based filtering
"""
none = "none"
random = "random"
top_k = "top_k"
top_p = "top_p"
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
"""
dialogs: list[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: str | None = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
:param statistics: (Optional) Statistical information about the generation process and filtering results
"""
synthetic_data: list[dict[str, Any]]
statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
def synthetic_data_generate(
self,
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: str | None = None,
) -> SyntheticDataGenerationResponse:
"""Generate synthetic data based on input dialogs and apply filtering.
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
:returns: Response containing filtered synthetic data samples and optional statistics
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import *

View file

@ -0,0 +1,423 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
Protocol,
runtime_checkable,
)
from pydantic import BaseModel, Field
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
@json_schema_type
class SpanStatus(Enum):
"""The status of a span indicating whether it completed successfully or with an error.
:cvar OK: Span completed successfully without errors
:cvar ERROR: Span completed with an error or failure
"""
OK = "ok"
ERROR = "error"
@json_schema_type
class Span(BaseModel):
"""A span representing a single operation within a trace.
:param span_id: Unique identifier for the span
:param trace_id: Unique identifier for the trace this span belongs to
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
:param name: Human-readable name describing the operation this span represents
:param start_time: Timestamp when the operation began
:param end_time: (Optional) Timestamp when the operation finished, if completed
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
"""
span_id: str
trace_id: str
parent_span_id: str | None = None
name: str
start_time: datetime
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
self.attributes = {}
self.attributes[key] = value
@json_schema_type
class Trace(BaseModel):
"""A trace representing the complete execution path of a request across multiple operations.
:param trace_id: Unique identifier for the trace
:param root_span_id: Unique identifier for the root span that started this trace
:param start_time: Timestamp when the trace began
:param end_time: (Optional) Timestamp when the trace finished, if completed
"""
trace_id: str
root_span_id: str
start_time: datetime
end_time: datetime | None = None
@json_schema_type
class EventType(Enum):
"""The type of telemetry event being logged.
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
:cvar STRUCTURED_LOG: A structured log event with typed payload data
:cvar METRIC: A metric measurement with value and unit
"""
UNSTRUCTURED_LOG = "unstructured_log"
STRUCTURED_LOG = "structured_log"
METRIC = "metric"
@json_schema_type
class LogSeverity(Enum):
"""The severity level of a log message.
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
:cvar DEBUG: Debug information useful during development
:cvar INFO: General informational messages about normal operation
:cvar WARN: Warning messages about potentially problematic situations
:cvar ERROR: Error messages indicating failures that don't stop execution
:cvar CRITICAL: Critical error messages indicating severe failures
"""
VERBOSE = "verbose"
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"
CRITICAL = "critical"
class EventCommon(BaseModel):
"""Common fields shared by all telemetry events.
:param trace_id: Unique identifier for the trace this event belongs to
:param span_id: Unique identifier for the span this event belongs to
:param timestamp: Timestamp when the event occurred
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
"""
trace_id: str
span_id: str
timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type
class UnstructuredLogEvent(EventCommon):
"""An unstructured log event containing a simple text message.
:param type: Event type identifier set to UNSTRUCTURED_LOG
:param message: The log message text
:param severity: The severity level of the log message
"""
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
"""A metric event containing a measured value.
:param type: Event type identifier set to METRIC
:param metric: The name of the metric being measured
:param value: The numeric value of the metric measurement
:param unit: The unit of measurement for the metric value
"""
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
"""A metric value included in API responses.
:param metric: The name of the metric
:param value: The numeric value of the metric
:param unit: (Optional) The unit of measurement for the metric value
"""
metric: str
value: int | float
unit: str | None = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
"""Mixin class for API responses that can include metrics.
:param metrics: (Optional) List of metrics associated with the API response
"""
metrics: list[MetricInResponse] | None = None
@json_schema_type
class StructuredLogType(Enum):
"""The type of structured log event payload.
:cvar SPAN_START: Event indicating the start of a new span
:cvar SPAN_END: Event indicating the completion of a span
"""
SPAN_START = "span_start"
SPAN_END = "span_end"
@json_schema_type
class SpanStartPayload(BaseModel):
"""Payload for a span start event.
:param type: Payload type identifier set to SPAN_START
:param name: Human-readable name describing the operation this span represents
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
"""
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str
parent_span_id: str | None = None
@json_schema_type
class SpanEndPayload(BaseModel):
"""Payload for a span end event.
:param type: Payload type identifier set to SPAN_END
:param status: The final status of the span indicating success or failure
"""
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus
StructuredLogPayload = Annotated[
SpanStartPayload | SpanEndPayload,
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
class StructuredLogEvent(EventCommon):
"""A structured log event containing typed payload data.
:param type: Event type identifier set to STRUCTURED_LOG
:param payload: The structured payload data for the log event
"""
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload
Event = Annotated[
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type
class EvalTrace(BaseModel):
"""A trace record for evaluation purposes.
:param session_id: Unique identifier for the evaluation session
:param step: The evaluation step or phase identifier
:param input: The input data for the evaluation
:param output: The actual output produced during evaluation
:param expected_output: The expected output for comparison during evaluation
"""
session_id: str
step: str
input: str
output: str
expected_output: str
@json_schema_type
class SpanWithStatus(Span):
"""A span that includes status information.
:param status: (Optional) The current status of the span
"""
status: SpanStatus | None = None
@json_schema_type
class QueryConditionOp(Enum):
"""Comparison operators for query conditions.
:cvar EQ: Equal to comparison
:cvar NE: Not equal to comparison
:cvar GT: Greater than comparison
:cvar LT: Less than comparison
"""
EQ = "eq"
NE = "ne"
GT = "gt"
LT = "lt"
@json_schema_type
class QueryCondition(BaseModel):
"""A condition for filtering query results.
:param key: The attribute key to filter on
:param op: The comparison operator to apply
:param value: The value to compare against
"""
key: str
op: QueryConditionOp
value: Any
class QueryTracesResponse(BaseModel):
"""Response containing a list of traces.
:param data: List of traces matching the query criteria
"""
data: list[Trace]
class QuerySpansResponse(BaseModel):
"""Response containing a list of spans.
:param data: List of spans matching the query criteria
"""
data: list[Span]
class QuerySpanTreeResponse(BaseModel):
"""Response containing a tree structure of spans.
:param data: Dictionary mapping span IDs to spans with status information
"""
data: dict[str, SpanWithStatus]
class MetricQueryType(Enum):
"""The type of metric query to perform.
:cvar RANGE: Query metrics over a time range
:cvar INSTANT: Query metrics at a specific point in time
"""
RANGE = "range"
INSTANT = "instant"
class MetricLabelOperator(Enum):
"""Operators for matching metric labels.
:cvar EQUALS: Label value must equal the specified value
:cvar NOT_EQUALS: Label value must not equal the specified value
:cvar REGEX_MATCH: Label value must match the specified regular expression
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
"""
EQUALS = "="
NOT_EQUALS = "!="
REGEX_MATCH = "=~"
REGEX_NOT_MATCH = "!~"
class MetricLabelMatcher(BaseModel):
"""A matcher for filtering metrics by label values.
:param name: The name of the label to match
:param value: The value to match against
:param operator: The comparison operator to use for matching
"""
name: str
value: str
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
@json_schema_type
class MetricLabel(BaseModel):
"""A label associated with a metric.
:param name: The name of the label
:param value: The value of the label
"""
name: str
value: str
@json_schema_type
class MetricDataPoint(BaseModel):
"""A single data point in a metric time series.
:param timestamp: Unix timestamp when the metric value was recorded
:param value: The numeric value of the metric at this timestamp
"""
timestamp: int
value: float
unit: str
@json_schema_type
class MetricSeries(BaseModel):
"""A time series of metric data points.
:param metric: The name of the metric
:param labels: List of labels associated with this metric series
:param values: List of data points in chronological order
"""
metric: str
labels: list[MetricLabel]
values: list[MetricDataPoint]
class QueryMetricsResponse(BaseModel):
"""Response containing metric time series data.
:param data: List of metric series matching the query criteria
"""
data: list[MetricSeries]
@runtime_checkable
class Telemetry(Protocol):
async def log_event(
self,
event: Event,
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
) -> None:
"""Log an event.
:param event: The event to log.
:param ttl_seconds: The time to live of the event.
"""
...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .tools import *

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
:param content: (Optional) The retrieved content from the query
:param metadata: Additional metadata about the query result
"""
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
:cvar default: Default query generator using simple text processing
:cvar llm: LLM-based query generator for enhanced query understanding
:cvar custom: Custom query generator implementation
"""
default = "default"
llm = "llm"
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
- KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
"""
VECTOR = "vector"
KEYWORD = "keyword"
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
:param type: Type of query generator, always 'default'
:param separator: String separator used to join query terms
"""
type: Literal["default"] = "default"
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
:param type: Type of query generator, always 'llm'
:param model: Name of the language model to use for query generation
:param template: Template string for formatting the query generation prompt
"""
type: Literal["llm"] = "llm"
model: str
template: str
RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_db_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_db_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...

View file

@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolDef(BaseModel):
"""Tool definition used in runtime contexts.
:param name: Name of the tool
:param description: (Optional) Human-readable description of what the tool does
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
:param metadata: (Optional) Additional metadata about the tool
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
"""
toolgroup_id: str | None = None
name: str
description: str | None = None
input_schema: dict[str, Any] | None = None
output_schema: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolGroupInput(BaseModel):
"""Input data for registering a tool group.
:param toolgroup_id: Unique identifier for the tool group
:param provider_id: ID of the provider that will handle this tool group
:param args: (Optional) Additional arguments to pass to the provider
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
"""
toolgroup_id: str
provider_id: str
args: dict[str, Any] | None = None
mcp_endpoint: URL | None = None
@json_schema_type
class ToolGroup(Resource):
"""A group of related tools managed together.
:param type: Type of resource, always 'tool_group'
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
:param args: (Optional) Additional arguments for the tool group
"""
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None
@json_schema_type
class ToolInvocationResult(BaseModel):
"""Result of a tool invocation.
:param content: (Optional) The output content from the tool execution
:param error_message: (Optional) Error message if the tool execution failed
:param error_code: (Optional) Numeric error code if the tool execution failed
:param metadata: (Optional) Additional metadata about the tool execution
"""
content: InterleavedContent | None = None
error_message: str | None = None
error_code: int | None = None
metadata: dict[str, Any] | None = None
class ToolStore(Protocol):
async def get_tool(self, tool_name: str) -> ToolDef: ...
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):
"""Response containing a list of tool groups.
:param data: List of tool groups
"""
data: list[ToolGroup]
class ListToolDefsResponse(BaseModel):
"""Response containing a list of tool definitions.
:param data: List of tool definitions
"""
data: list[ToolDef]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group.
:param toolgroup_id: The ID of the tool group to register.
:param provider_id: The ID of the provider to use for the tool group.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param args: A dictionary of arguments to pass to the tool group.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_tool_group(
self,
toolgroup_id: str,
) -> ToolGroup:
"""Get a tool group by its ID.
:param toolgroup_id: The ID of the tool group to get.
:returns: A ToolGroup.
"""
...
@webmethod(route="/toolgroups", method="GET", level=LLAMA_STACK_API_V1)
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider.
:returns: A ListToolGroupsResponse.
"""
...
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
"""List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tools/{tool_name:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_tool(
self,
tool_name: str,
) -> ToolDef:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A ToolDef.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_toolgroup(
self,
toolgroup_id: str,
) -> None:
"""Unregister a tool group.
:param toolgroup_id: The ID of the tool group to unregister.
"""
...
class SpecialToolGroup(Enum):
"""Special tool groups with predefined functionality.
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
"""
rag_tool = "rag_tool"
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:returns: A ToolInvocationResult.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_io import *

View file

@ -0,0 +1,960 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@json_schema_type
class ChunkMetadata(BaseModel):
"""
`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata`
is set during chunk creation in `MemoryToolRuntimeImpl().insert()`and is not expected to change after.
Use `Chunk.metadata` for metadata that will be used in the context during inference.
:param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
:param document_id: The ID of the document this chunk belongs to.
:param source: The source of the content, such as a URL, file path, or other identifier.
:param created_timestamp: An optional timestamp indicating when the chunk was created.
:param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
:param chunk_window: The window of the chunk, which can be used to group related chunks together.
:param chunk_tokenizer: The tokenizer used to create the chunk. Default is Tiktoken.
:param chunk_embedding_model: The embedding model used to create the chunk's embedding.
:param chunk_embedding_dimension: The dimension of the embedding vector for the chunk.
:param content_token_count: The number of tokens in the content of the chunk.
:param metadata_token_count: The number of tokens in the metadata of the chunk.
"""
chunk_id: str | None = None
document_id: str | None = None
source: str | None = None
created_timestamp: int | None = None
updated_timestamp: int | None = None
chunk_window: str | None = None
chunk_tokenizer: str | None = None
chunk_embedding_model: str | None = None
chunk_embedding_dimension: int | None = None
content_token_count: int | None = None
metadata_token_count: int | None = None
@json_schema_type
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required backend functionality.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
chunk_metadata: ChunkMetadata | None = None
model_config = {"populate_by_name": True}
def model_post_init(self, __context):
# Extract chunk_id from metadata if present
if self.metadata and "chunk_id" in self.metadata:
self.stored_chunk_id = self.metadata.pop("chunk_id")
@property
def chunk_id(self) -> str:
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
if self.stored_chunk_id:
return self.stored_chunk_id
if "document_id" in self.metadata:
return generate_chunk_id(self.metadata["document_id"], str(self.content))
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
@property
def document_id(self) -> str | None:
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
# Check metadata first (takes precedence)
doc_id = self.metadata.get("document_id")
if doc_id is not None:
if not isinstance(doc_id, str):
raise TypeError(f"metadata['document_id'] must be a string, got {type(doc_id).__name__}: {doc_id!r}")
return doc_id
# Fall back to chunk_metadata if available (Pydantic ensures type safety)
if self.chunk_metadata is not None:
return self.chunk_metadata.document_id
return None
@json_schema_type
class QueryChunksResponse(BaseModel):
"""Response from querying chunks in a vector database.
:param chunks: List of content chunks returned from the query
:param scores: Relevance scores corresponding to each returned chunk
"""
chunks: list[Chunk]
scores: list[float]
@json_schema_type
class VectorStoreFileCounts(BaseModel):
"""File processing status counts for a vector store.
:param completed: Number of files that have been successfully processed
:param cancelled: Number of files that had their processing cancelled
:param failed: Number of files that failed to process
:param in_progress: Number of files currently being processed
:param total: Total number of files in the vector store
"""
completed: int
cancelled: int
failed: int
in_progress: int
total: int
# TODO: rename this as OpenAIVectorStore
@json_schema_type
class VectorStoreObject(BaseModel):
"""OpenAI Vector Store object.
:param id: Unique identifier for the vector store
:param object: Object type identifier, always "vector_store"
:param created_at: Timestamp when the vector store was created
:param name: (Optional) Name of the vector store
:param usage_bytes: Storage space used by the vector store in bytes
:param file_counts: File processing status counts for the vector store
:param status: Current status of the vector store
:param expires_after: (Optional) Expiration policy for the vector store
:param expires_at: (Optional) Timestamp when the vector store will expire
:param last_active_at: (Optional) Timestamp of last activity on the vector store
:param metadata: Set of key-value pairs that can be attached to the vector store
"""
id: str
object: str = "vector_store"
created_at: int
name: str | None = None
usage_bytes: int = 0
file_counts: VectorStoreFileCounts
status: str = "completed"
expires_after: dict[str, Any] | None = None
expires_at: int | None = None
last_active_at: int | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreCreateRequest(BaseModel):
"""Request to create a vector store.
:param name: (Optional) Name for the vector store
:param file_ids: List of file IDs to include in the vector store
:param expires_after: (Optional) Expiration policy for the vector store
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
:param metadata: Set of key-value pairs that can be attached to the vector store
"""
name: str | None = None
file_ids: list[str] = Field(default_factory=list)
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreModifyRequest(BaseModel):
"""Request to modify a vector store.
:param name: (Optional) Updated name for the vector store
:param expires_after: (Optional) Updated expiration policy for the vector store
:param metadata: (Optional) Updated set of key-value pairs for the vector store
"""
name: str | None = None
expires_after: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class VectorStoreListResponse(BaseModel):
"""Response from listing vector stores.
:param object: Object type identifier, always "list"
:param data: List of vector store objects
:param first_id: (Optional) ID of the first vector store in the list for pagination
:param last_id: (Optional) ID of the last vector store in the list for pagination
:param has_more: Whether there are more vector stores available beyond this page
"""
object: str = "list"
data: list[VectorStoreObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
@json_schema_type
class VectorStoreSearchRequest(BaseModel):
"""Request to search a vector store.
:param query: Search query as a string or list of strings
:param filters: (Optional) Filters based on file attributes to narrow search results
:param max_num_results: Maximum number of results to return, defaults to 10
:param ranking_options: (Optional) Options for ranking and filtering search results
:param rewrite_query: Whether to rewrite the query for better vector search performance
"""
query: str | list[str]
filters: dict[str, Any] | None = None
max_num_results: int = 10
ranking_options: dict[str, Any] | None = None
rewrite_query: bool = False
@json_schema_type
class VectorStoreContent(BaseModel):
"""Content item from a vector store file or search result.
:param type: Content type, currently only "text" is supported
:param text: The actual text content
"""
type: Literal["text"]
text: str
@json_schema_type
class VectorStoreSearchResponse(BaseModel):
"""Response from searching a vector store.
:param file_id: Unique identifier of the file containing the result
:param filename: Name of the file containing the result
:param score: Relevance score for this search result
:param attributes: (Optional) Key-value attributes associated with the file
:param content: List of content items matching the search query
"""
file_id: str
filename: str
score: float
attributes: dict[str, str | float | bool] | None = None
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreSearchResponsePage(BaseModel):
"""Paginated response from searching a vector store.
:param object: Object type identifier for the search results page
:param search_query: The original search query that was executed
:param data: List of search result objects
:param has_more: Whether there are more results available beyond this page
:param next_page: (Optional) Token for retrieving the next page of results
"""
object: str = "vector_store.search_results.page"
search_query: str
data: list[VectorStoreSearchResponse]
has_more: bool = False
next_page: str | None = None
@json_schema_type
class VectorStoreDeleteResponse(BaseModel):
"""Response from deleting a vector store.
:param id: Unique identifier of the deleted vector store
:param object: Object type identifier for the deletion response
:param deleted: Whether the deletion operation was successful
"""
id: str
object: str = "vector_store.deleted"
deleted: bool = True
@json_schema_type
class VectorStoreChunkingStrategyAuto(BaseModel):
"""Automatic chunking strategy for vector store files.
:param type: Strategy type, always "auto" for automatic chunking
"""
type: Literal["auto"] = "auto"
@json_schema_type
class VectorStoreChunkingStrategyStaticConfig(BaseModel):
"""Configuration for static chunking strategy.
:param chunk_overlap_tokens: Number of tokens to overlap between adjacent chunks
:param max_chunk_size_tokens: Maximum number of tokens per chunk, must be between 100 and 4096
"""
chunk_overlap_tokens: int = 400
max_chunk_size_tokens: int = Field(800, ge=100, le=4096)
@json_schema_type
class VectorStoreChunkingStrategyStatic(BaseModel):
"""Static chunking strategy with configurable parameters.
:param type: Strategy type, always "static" for static chunking
:param static: Configuration parameters for the static chunking strategy
"""
type: Literal["static"] = "static"
static: VectorStoreChunkingStrategyStaticConfig
VectorStoreChunkingStrategy = Annotated[
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
Field(discriminator="type"),
]
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
class SearchRankingOptions(BaseModel):
"""Options for ranking and filtering search results.
:param ranker: (Optional) Name of the ranking algorithm to use
:param score_threshold: (Optional) Minimum relevance score threshold for results
"""
ranker: str | None = None
# NOTE: OpenAI File Search Tool requires threshold to be between 0 and 1, however
# we don't guarantee that the score is between 0 and 1, so will leave this unconstrained
# and let the provider handle it
score_threshold: float | None = Field(default=0.0)
@json_schema_type
class VectorStoreFileLastError(BaseModel):
"""Error information for failed vector store file processing.
:param code: Error code indicating the type of failure
:param message: Human-readable error message describing the failure
"""
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
message: str
VectorStoreFileStatus = Literal["completed"] | Literal["in_progress"] | Literal["cancelled"] | Literal["failed"]
register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus")
@json_schema_type
class VectorStoreFileObject(BaseModel):
"""OpenAI Vector Store File object.
:param id: Unique identifier for the file
:param object: Object type identifier, always "vector_store.file"
:param attributes: Key-value attributes associated with the file
:param chunking_strategy: Strategy used for splitting the file into chunks
:param created_at: Timestamp when the file was added to the vector store
:param last_error: (Optional) Error information if file processing failed
:param status: Current processing status of the file
:param usage_bytes: Storage space used by this file in bytes
:param vector_store_id: ID of the vector store containing this file
"""
id: str
object: str = "vector_store.file"
attributes: dict[str, Any] = Field(default_factory=dict)
chunking_strategy: VectorStoreChunkingStrategy
created_at: int
last_error: VectorStoreFileLastError | None = None
status: VectorStoreFileStatus
usage_bytes: int = 0
vector_store_id: str
@json_schema_type
class VectorStoreListFilesResponse(BaseModel):
"""Response from listing files in a vector store.
:param object: Object type identifier, always "list"
:param data: List of vector store file objects
:param first_id: (Optional) ID of the first file in the list for pagination
:param last_id: (Optional) ID of the last file in the list for pagination
:param has_more: Whether there are more files available beyond this page
"""
object: str = "list"
data: list[VectorStoreFileObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
@json_schema_type
class VectorStoreFileContentsResponse(BaseModel):
"""Response from retrieving the contents of a vector store file.
:param file_id: Unique identifier for the file
:param filename: Name of the file
:param attributes: Key-value attributes associated with the file
:param content: List of content items from the file
"""
file_id: str
filename: str
attributes: dict[str, Any]
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreFileDeleteResponse(BaseModel):
"""Response from deleting a vector store file.
:param id: Unique identifier of the deleted file
:param object: Object type identifier for the deletion response
:param deleted: Whether the deletion operation was successful
"""
id: str
object: str = "vector_store.file.deleted"
deleted: bool = True
@json_schema_type
class VectorStoreFileBatchObject(BaseModel):
"""OpenAI Vector Store File Batch object.
:param id: Unique identifier for the file batch
:param object: Object type identifier, always "vector_store.file_batch"
:param created_at: Timestamp when the file batch was created
:param vector_store_id: ID of the vector store containing the file batch
:param status: Current processing status of the file batch
:param file_counts: File processing status counts for the batch
"""
id: str
object: str = "vector_store.file_batch"
created_at: int
vector_store_id: str
status: VectorStoreFileStatus
file_counts: VectorStoreFileCounts
@json_schema_type
class VectorStoreFilesListInBatchResponse(BaseModel):
"""Response from listing files in a vector store file batch.
:param object: Object type identifier, always "list"
:param data: List of vector store file objects in the batch
:param first_id: (Optional) ID of the first file in the list for pagination
:param last_id: (Optional) ID of the last file in the list for pagination
:param has_more: Whether there are more files available beyond this page
"""
object: str = "list"
data: list[VectorStoreFileObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store with extra_body support.
:param name: (Optional) A name for the vector store
:param file_ids: List of file IDs to include in the vector store
:param expires_after: (Optional) Expiration policy for the vector store
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
:param metadata: Set of key-value pairs that can be attached to the vector store
"""
name: str | None = None
file_ids: list[str] | None = None
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store file batch with extra_body support.
:param file_ids: A list of File IDs that the vector store should use
:param attributes: (Optional) Key-value attributes to store with the files
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto
"""
file_ids: list[str]
attributes: dict[str, Any] | None = None
chunking_strategy: VectorStoreChunkingStrategy | None = None
class VectorStoreTable(Protocol):
def get_vector_store(self, vector_store_id: str) -> VectorStore | None: ...
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
vector_store_table: VectorStoreTable | None = None
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
"""Creates a vector store.
Generate an OpenAI-compatible vector store with the given parameters.
:returns: A VectorStoreObject representing the created vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:returns: A VectorStoreListResponse containing the list of vector stores.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store.
:param vector_store_id: The ID of the vector store to retrieve.
:returns: A VectorStoreObject representing the vector store.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Updates a vector store.
:param vector_store_id: The ID of the vector store to update.
:param name: The name of the vector store.
:param expires_after: The expiration policy for a vector store.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:returns: A VectorStoreObject representing the updated vector store.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store.
:param vector_store_id: The ID of the vector store to delete.
:returns: A VectorStoreDeleteResponse indicating the deletion status.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: (
str | None
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
:param vector_store_id: The ID of the vector store to search.
:param query: The query string or array for performing the search.
:param filters: Filters based on file attributes to narrow the search results.
:param max_num_results: Maximum number of results to return (1 to 50 inclusive, default 10).
:param ranking_options: Ranking options for fine-tuning the search results.
:param rewrite_query: Whether to rewrite the natural language query for vector search (default false)
:param search_mode: The search mode to use - "keyword", "vector", or "hybrid" (default "vector")
:returns: A VectorStoreSearchResponse containing the search results.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
"""Attach a file to a vector store.
:param vector_store_id: The ID of the vector store to attach the file to.
:param file_id: The ID of the file to attach to the vector store.
:param attributes: The key-value attributes stored with the file, which can be used for filtering.
:param chunking_strategy: The chunking strategy to use for the file.
:returns: A VectorStoreFileObject representing the attached file.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
"""List files in a vector store.
:param vector_store_id: The ID of the vector store to list files from.
:param limit: (Optional) A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: (Optional) Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param after: (Optional) A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: (Optional) A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:param filter: (Optional) Filter by file status to only return files with the specified status.
:returns: A VectorStoreListFilesResponse containing the list of files.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
"""Retrieves a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A VectorStoreFileObject representing the file.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
"""Retrieves the contents of a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A list of InterleavedContent representing the file contents.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
"""Updates a vector store file.
:param vector_store_id: The ID of the vector store containing the file to update.
:param file_id: The ID of the file to update.
:param attributes: The updated key-value attributes to store with the file.
:returns: A VectorStoreFileObject representing the updated file.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
"""Delete a vector store file.
:param vector_store_id: The ID of the vector store containing the file to delete.
:param file_id: The ID of the file to delete.
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.
Generate an OpenAI-compatible vector store file batch for the given vector store.
:param vector_store_id: The ID of the vector store to create the file batch for.
:returns: A VectorStoreFileBatchObject representing the created file batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch.
:param batch_id: The ID of the file batch to retrieve.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the file batch.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
"""Returns a list of vector store files in a batch.
:param batch_id: The ID of the file batch to list files from.
:param vector_store_id: The ID of the vector store containing the file batch.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:param filter: Filter by file status. One of in_progress, completed, failed, cancelled.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:returns: A VectorStoreFilesListInBatchResponse containing the list of files in the batch.
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancels a vector store file batch.
:param batch_id: The ID of the file batch to cancel.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the cancelled file batch.
"""
...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_stores import *

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
# Internal resource type for storing the vector store routing and other information
class VectorStore(Resource):
"""Vector database resource for storing and querying vector embeddings.
:param type: Type of resource, always 'vector_store' for vector stores
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
"""
type: Literal[ResourceType.vector_store] = ResourceType.vector_store
embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
@property
def vector_store_id(self) -> str:
return self.identifier
@property
def provider_vector_store_id(self) -> str | None:
return self.provider_resource_id
class VectorStoreInput(BaseModel):
"""Input parameters for creating or configuring a vector database.
:param vector_store_id: Unique identifier for the vector store
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
"""
vector_store_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_store_id: str | None = None

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_API_V1 = "v1"
LLAMA_STACK_API_V1BETA = "v1beta"
LLAMA_STACK_API_V1ALPHA = "v1alpha"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.log import setup_logging
from .stack import StackParser
from .stack.utils import print_subcommand_description
class LlamaCLIParser:
"""Defines CLI parser for Llama CLI"""
def __init__(self):
self.parser = argparse.ArgumentParser(
prog="llama",
description="Welcome to the Llama CLI",
add_help=True,
formatter_class=argparse.RawTextHelpFormatter,
)
# Default command is to print help
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
StackParser.create(subparsers)
print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace:
args = self.parser.parse_args()
if not isinstance(args, argparse.Namespace):
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
return args
def run(self, args: argparse.Namespace) -> None:
args.func(args)
def main():
# Initialize logging from environment variables before any other operations
setup_logging()
parser = LlamaCLIParser()
args = parser.parse_args()
parser.run(args)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
if [ $# -eq 0 ]; then
echo "Please provide a URL as an argument."
exit 1
fi
URL=$1
HEADERS_FILE=$(mktemp)
curl -s -I "$URL" >"$HEADERS_FILE"
FILENAME=$(grep -i "x-manifold-obj-canonicalpath:" "$HEADERS_FILE" | sed -E 's/.*nodes\/[^\/]+\/(.+)/\1/' | tr -d "\r\n")
if [ -z "$FILENAME" ]; then
echo "Could not find the x-manifold-obj-canonicalpath header."
echo "HEADERS_FILE contents: "
cat "$HEADERS_FILE"
echo ""
exit 1
fi
echo "Downloading $FILENAME..."
curl -s -L -o "$FILENAME" "$URL"
echo "Installing $FILENAME..."
pip install "$FILENAME"
echo "Successfully installed $FILENAME"
rm -f "$FILENAME"

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import subprocess
import sys
def install_wheel_from_presigned():
file = "install-wheel-from-presigned.sh"
script_path = os.path.join(os.path.dirname(__file__), file)
try:
subprocess.run(["sh", script_path] + sys.argv[1:], check=True)
except Exception:
sys.exit(1)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .stack import StackParser # noqa

View file

@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.core.build import get_provider_dependencies
from llama_stack.core.datatypes import (
BuildConfig,
BuildProvider,
DistributionSpec,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
logger = get_logger(name=__name__, category="cli")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
"uvicorn",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
]
def format_output_deps_only(
normal_deps: list[str],
special_deps: list[str],
external_deps: list[str],
uv: bool = False,
) -> str:
"""Format dependencies as a list."""
lines = []
uv_str = ""
if uv:
uv_str = "uv pip install "
# Quote deps with commas
quoted_normal_deps = [quote_if_needed(dep) for dep in normal_deps]
lines.append(f"{uv_str}{' '.join(quoted_normal_deps)}")
for special_dep in special_deps:
lines.append(f"{uv_str}{quote_special_dep(special_dep)}")
for external_dep in external_deps:
lines.append(f"{uv_str}{quote_special_dep(external_dep)}")
return "\n".join(lines)
def run_stack_list_deps_command(args: argparse.Namespace) -> None:
if args.config:
try:
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
config_file = resolve_config_or_distro(args.config, Mode.BUILD)
except ValueError as e:
cprint(
f"Could not parse config file {args.config}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if config_file:
with open(config_file) as f:
try:
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
build_config.image_type = "venv"
except Exception as e:
cprint(
f"Could not parse config file {config_file}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
elif args.providers:
provider_list: dict[str, list[BuildProvider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider_type in providers_for_api:
provider = BuildProvider(
provider_type=provider_type,
module=None,
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider_type} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=provider_list,
description=",".join(args.providers),
)
build_config = BuildConfig(image_type=ImageType.VENV.value, distribution_spec=distribution_spec)
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
# Add external API dependencies
if build_config.external_apis_dir:
from llama_stack.core.external import load_external_apis
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
# Format and output based on requested format
output = format_output_deps_only(
normal_deps=normal_deps,
special_deps=special_deps,
external_deps=external_provider_dependencies,
uv=args.format == "uv",
)
print(output)
def quote_if_needed(dep):
# Add quotes if the dependency contains special characters that need escaping in shell
# This includes: commas, comparison operators (<, >, <=, >=, ==, !=)
needs_quoting = any(char in dep for char in [",", "<", ">", "="])
return f"'{dep}'" if needs_quoting else dep
def quote_special_dep(dep_string):
"""
Quote individual packages in a special dependency string.
Special deps may contain multiple packages and flags like --extra-index-url.
We need to quote only the package specs that contain special characters.
"""
parts = dep_string.split()
quoted_parts = []
for part in parts:
# Don't quote flags (they start with -)
if part.startswith("-"):
quoted_parts.append(part)
else:
# Quote package specs that need it
quoted_parts.append(quote_if_needed(part))
return " ".join(quoted_parts)

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackListApis(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list-apis",
prog="llama stack list-apis",
description="List APIs part of the Llama Stack implementation",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_apis_list_cmd)
def _add_arguments(self):
pass
def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.core.distribution import stack_apis
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"API",
]
rows = []
for api in stack_apis():
rows.append(
[
api.value,
]
)
print_table(
rows,
headers,
separate_rows=True,
)

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackListDeps(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list-deps",
prog="llama stack list-deps",
description="list the dependencies for a llama stack distribution",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_list_deps_command)
def _add_arguments(self):
self.parser.add_argument(
"config",
type=str,
nargs="?", # Make it optional
metavar="config | distro",
help="Path to config file to use or name of known distro (llama stack list for a list).",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="sync dependencies for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)
self.parser.add_argument(
"--format",
type=str,
choices=["uv", "deps-only"],
default="deps-only",
help="Output format: 'uv' shows shell commands, 'deps-only' shows just the list of dependencies without `uv` (default)",
)
def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
from ._list_deps import run_stack_list_deps_command
return run_stack_list_deps_command(args)

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackListProviders(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list-providers",
prog="llama stack list-providers",
description="Show available Llama Stack Providers for an API",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_list_cmd)
@property
def providable_apis(self):
from llama_stack.core.distribution import providable_apis
return [api.value for api in providable_apis()]
def _add_arguments(self):
self.parser.add_argument(
"api",
type=str,
choices=self.providable_apis,
nargs="?",
help="API to list providers for. List all if not specified.",
)
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.core.distribution import Api, get_provider_registry
all_providers = get_provider_registry()
if args.api:
providers = [(args.api, all_providers[Api(args.api)])]
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [(api, p) for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"API Type",
"Provider Type",
"PIP Package Dependencies",
]
rows = []
specs = [spec for api, p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
rows.append(
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
]
)
print_table(
rows,
headers,
separate_rows=True,
sort_by=(0, 1),
)

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackListBuilds(Subcommand):
"""List built stacks in .llama/distributions directory"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama stack list",
description="list the build stacks",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._list_stack_command)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stack_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
return
headers = ["Stack Name", "Path"]
headers.extend(["Build Config", "Run Config"])
rows = []
for name, path in distributions.items():
row = [name, str(path)]
# Check for build and run config files
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
row.extend([build_config, run_config])
rows.append(row)
print_table(rows, headers, separate_rows=True)

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackRemove(Subcommand):
"""Remove the build stack"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"rm",
prog="llama stack rm",
description="Remove the build stack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._remove_stack_build_command)
def _add_arguments(self) -> None:
self.parser.add_argument(
"name",
type=str,
nargs="?",
help="Name of the stack to delete",
)
self.parser.add_argument(
"--all",
"-a",
action="store_true",
help="Delete all stacks (use with caution)",
)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stacks(self) -> None:
"""Display available stacks in a table"""
distributions = self._get_distribution_dirs()
if not distributions:
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
sys.exit(1)
headers = ["Stack Name", "Path"]
rows = [[name, str(path)] for name, path in distributions.items()]
print_table(rows, headers, separate_rows=True)
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if args.all:
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
if confirm != "yes-i-really-want":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
for name, path in distributions.items():
try:
shutil.rmtree(path)
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
except Exception as e:
cprint(
f"Failed to delete stack {name}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if not args.name:
self._list_stacks()
if not args.name:
return
if args.name not in distributions:
self._list_stacks()
cprint(
f"Stack not found: {args.name}",
color="red",
file=sys.stderr,
)
sys.exit(1)
stack_path = distributions[args.name]
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
if confirm != "y":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
try:
shutil.rmtree(stack_path)
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
except Exception as e:
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
sys.exit(1)

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import ssl
import subprocess
from pathlib import Path
import uvicorn
import yaml
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import LoggingConfig, get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"run",
prog="llama stack run",
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_run_cmd)
def _add_arguments(self):
self.parser.add_argument(
"config",
type=str,
nargs="?", # Make it optional
metavar="config | distro",
help="Path to config file to use for the run or name of known distro (`llama stack list` for a list).",
)
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
)
self.parser.add_argument(
"--image-name",
type=str,
default=None,
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
)
self.parser.add_argument(
"--image-type",
type=str,
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
)
self.parser.add_argument(
"--enable-ui",
action="store_true",
help="Start the UI server",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.core.configure import parse_and_maybe_upgrade_config
if args.image_type or args.image_name:
self.parser.error(
"The --image-type and --image-name flags are no longer supported.\n\n"
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
"For example:\n"
" source /path/to/venv/bin/activate\n"
" llama stack run <config>\n"
)
if args.enable_ui:
self._start_ui_development_server(args.port)
if args.config:
try:
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
config_file = resolve_config_or_distro(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
else:
config_file = None
if config_file:
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
self._uvicorn_run(config_file, args)
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
if not config_file:
self.parser.error("Config file is required")
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
uvicorn_config = {
"factory": True,
"host": host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
if config.server.tls_cafile:
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"Listening on {host}:{port}")
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")
# Check if npm is available
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
if npm_check.returncode != 0:
logger.warning(
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
)
return
ui_dir = REPO_ROOT / "llama_stack" / "ui"
logs_dir = Path("~/.llama/ui/logs").expanduser()
try:
# Create logs directory if it doesn't exist
logs_dir.mkdir(parents=True, exist_ok=True)
ui_stdout_log_path = logs_dir / "stdout.log"
ui_stderr_log_path = logs_dir / "stderr.log"
# Open log files in append mode
stdout_log_file = open(ui_stdout_log_path, "a")
stderr_log_file = open(ui_stderr_log_path, "a")
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=str(ui_dir),
stdout=stdout_log_file,
stderr=stderr_log_file,
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
)
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
except FileNotFoundError:
logger.error(
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")

View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from importlib.metadata import version
from llama_stack.cli.stack.list_stacks import StackListBuilds
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .list_apis import StackListApis
from .list_deps import StackListDeps
from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun
class StackParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.add_argument(
"--version",
action="version",
version=f"{version('llama-stack')}",
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackListDeps.create(subparsers)
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)
StackRemove.create(subparsers)
StackListBuilds.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import sys
from enum import Enum
from functools import lru_cache
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.core.datatypes import (
BuildConfig,
Provider,
StackRunConfig,
StorageConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
)
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
class ImageType(Enum):
CONTAINER = "container"
VENV = "venv"
def print_subcommand_description(parser, subparsers):
"""Print descriptions of subcommands."""
description_text = ""
for name, subcommand in subparsers.choices.items():
description = subcommand.description
description_text += f" {name:<21} {description}\n"
parser.epilog = description_text
def generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> Path:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
apis = list(build_config.distribution_spec.providers.keys())
distro_dir = DISTRIBS_BASE_DIR / image_name
run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name,
apis=apis,
providers={},
storage=StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=str(distro_dir / "kvstore.db")),
"sql_default": SqliteSqlStoreConfig(db_path=str(distro_dir / "sql_store.db")),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
),
),
external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
# build providers dict
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
providers = build_config.distribution_spec.providers[api]
for provider in providers:
pid = provider.provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
except (ModuleNotFoundError, ValueError) as exc:
# HACK ALERT:
# This code executes after building is done, the import cannot work since the
# package is either available in the venv or container - not available on the host.
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
color="yellow",
file=sys.stderr,
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
p_spec = Provider(
provider_id=pid,
provider_type=provider.provider_type,
config=config,
module=provider.module,
)
run_config.providers[api].append(p_spec)
run_config_file = build_dir / f"{image_name}-run.yaml"
with open(run_config_file, "w") as f:
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
# Only print this message for non-container builds since it will be displayed before the
# container is built
# For non-container builds, the run.yaml is generated at the very end of the build process so it
# makes sense to display this message
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
return run_config_file
@lru_cache
def available_templates_specs() -> dict[str, BuildConfig]:
import yaml
template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name
with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config
return template_specs

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
class Subcommand:
"""All llama cli subcommands must inherit from this class"""
def __init__(self, *args, **kwargs):
pass
@classmethod
def create(cls, *args, **kwargs):
return cls(*args, **kwargs)
def _add_arguments(self):
pass

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from rich.console import Console
from rich.table import Table
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
# Convert rows and handle None values
rows = [[x or "" for x in row] for row in rows]
# Sort rows if sort_by is specified
if sort_by:
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
# Create Rich table
table = Table(show_lines=separate_rows)
# Add headers if provided
if headers:
for header in headers:
table.add_column(header, style="bold white")
else:
# Add unnamed columns based on first row
for _ in range(len(rows[0]) if rows else 0):
table.add_column()
# Add rows
for row in rows:
table.add_row(*row)
# Print table
console = Console()
console.print(table)

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="cli")
# TODO: this can probably just be inlined now?
def add_config_distro_args(parser: argparse.ArgumentParser):
"""Add unified config/distro arguments."""
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"config",
nargs="?",
help="Configuration file path or distribution name",
)
def get_config_from_args(args: argparse.Namespace) -> str | None:
if args.config is not None:
return str(args.config)
return None

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import (
AccessRule,
Action,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
if resource_scope == actual_resource:
return True
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
def matches_scope(
scope: Scope,
action: Action,
resource: str,
user: str | None,
) -> bool:
if scope.resource and not matches_resource(scope.resource, resource):
return False
if scope.principal and scope.principal != user:
return False
return action in scope.actions
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
return [obj]
def matches_conditions(
conditions: list[Condition],
resource: ProtectedResource,
user: User,
) -> bool:
for condition in conditions:
# must match all conditions
if not condition.matches(resource, user):
return False
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided, assume
# full access subject to previous attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
),
]
def is_action_allowed(
policy: list[AccessRule],
action: Action,
resource: ProtectedResource,
user: User | None,
) -> bool:
# If user is not set, assume authentication is not enabled
if not user:
return True
if not len(policy):
policy = default_policy()
qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return False
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return True
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return True
else:
return True
# assume access is denied unless we find a rule that permits access
return False
class AccessDeniedError(RuntimeError):
def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
self.action = action
self.resource = resource
self.user = user
message = _build_access_denied_message(action, resource, user)
super().__init__(message)
def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str:
"""Build detailed error message for access denied scenarios."""
if action and resource and user:
resource_info = f"{resource.type}::{resource.identifier}"
user_info = f"'{user.principal}'"
if user.attributes:
attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
user_info += f" (attributes: {attrs})"
message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
else:
message = "Insufficient permissions"
return message

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
type: str
identifier: str
owner: User
class Condition(Protocol):
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
class UserInOwnersList:
def __init__(self, name: str):
self.name = name
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
if (
hasattr(resource, "owner")
and resource.owner
and resource.owner.attributes
and self.name in resource.owner.attributes
):
return resource.owner.attributes[self.name]
else:
return None
def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
if value in user_values:
return True
return False
def __repr__(self):
return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
def __init__(self, name: str):
super().__init__(name)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user not in owners {self.name}"
class UserWithValueInList:
def __init__(self, name: str, value: str):
self.name = name
self.value = value
def matches(self, resource: ProtectedResource, user: User) -> bool:
if user.attributes and self.name in user.attributes:
return self.value in user.attributes[self.name]
print(f"User does not have {self.value} in {self.name}")
return False
def __repr__(self):
return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
def __init__(self, name: str, value: str):
super().__init__(name, value)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user with {self.value} not in {self.name}"
class UserIsOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return resource.owner.principal == user.principal if resource.owner else False
def __repr__(self):
return "user is owner"
class UserIsNotOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner or resource.owner.principal != user.principal
def __repr__(self):
return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
return [parse_condition(c) for c in conditions]

View file

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Self
from pydantic import BaseModel, model_validator
from .conditions import parse_conditions
class Action(StrEnum):
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
class Scope(BaseModel):
principal: str | None = None
actions: Action | list[Action]
resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
if getattr(obj, a) and getattr(obj, b):
raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
if not getattr(obj, a) and not getattr(obj, b):
raise ValueError(f"on of {a} or {b} is required")
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints supported at present are:
- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified, a rule that allows any action as
long as the resource attributes match the user attributes is added
(i.e. the previous behaviour is the default).
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when: user in owner teams
description: any user has read access to any resource created by a member of their team
- forbid:
actions: [create, read, delete]
resource: vector_store::*
unless: user with admin in roles
description: only user with admin role can use vector_store resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: str | list[str] | None = None
unless: str | list[str] | None = None
description: str | None = None
@model_validator(mode="after")
def validate_rule_format(self) -> Self:
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
if isinstance(self.when, list):
parse_conditions(self.when)
elif self.when:
parse_conditions([self.when])
if isinstance(self.unless, list):
parse_conditions(self.unless)
elif self.unless:
parse_conditions([self.unless])
return self

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import sys
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = get_logger(name=__name__, category="core")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
"uvicorn",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
]
class ApiInput(BaseModel):
api: Api
provider: str
def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
if isinstance(config, DistributionTemplate):
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
external_provider_deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
# this ensures we install the top level module for our external providers
if provider_spec.module:
if isinstance(provider_spec.module, str):
external_provider_deps.append(provider_spec.module)
else:
external_provider_deps.extend(provider_spec.module)
if hasattr(provider_spec, "pip_packages"):
deps.extend(provider_spec.pip_packages)
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
raise ValueError("A stack's dependencies cannot have a container image")
normal_deps = []
special_deps = []
for package in deps:
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
special_deps.append(package)
else:
normal_deps.append(package)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps, _ = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
color="yellow",
file=sys.stderr,
)
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()
def build_image(
build_config: BuildConfig,
image_name: str,
distro_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
args = [
script,
"--distro-or-config",
distro_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.extend(["--run-config", run_config])
else:
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps:
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)
if return_code != 0:
log.error(
f"Failed to build target {image_name} with return code {return_code}",
)
return return_code

View file

@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
self.base_url = base_url.rstrip("/")
self.routes = {}
# Store routes for this protocol
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
async def shutdown(self):
pass
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
# TODO: make this more precise, same thing needs to happen in server.py
is_streaming = kwargs.get("stream", False)
if is_streaming:
return self._call_streaming(method_name, *args, **kwargs)
else:
return await self._call_non_streaming(method_name, *args, **kwargs)
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
_, sig = self.routes[method_name]
if sig.return_annotation is None:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
response = await client.request(**params)
response.raise_for_status()
j = response.json()
if j is None:
return None
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
return parse_obj_as(return_type, j)
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
async with client.stream(**params) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
data = json.loads(data)
if "error" in data:
cprint(data, color="red", file=sys.stderr)
continue
yield parse_obj_as(return_type, data)
except Exception as e:
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
cprint(data, color="red", file=sys.stderr)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]
parameters = list(sig.parameters.values())[1:] # skip `self`
for i, param in enumerate(parameters):
if i >= len(args):
break
kwargs[param.name] = args[i]
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
if not webmethods:
raise RuntimeError(f"Method {method} has no webmethod decorators")
# Choose the preferred webmethod (non-deprecated if available)
preferred_webmethod = None
for wm in webmethods:
if not getattr(wm, "deprecated", False):
preferred_webmethod = wm
break
# If no non-deprecated found, use the first one
if preferred_webmethod is None:
preferred_webmethod = webmethods[0]
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
def convert(value):
if isinstance(value, list):
return [convert(v) for v in value]
elif isinstance(value, dict):
return {k: convert(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
else:
return value
params = {}
data = {}
if webmethod.method == "GET":
params.update(kwargs)
else:
data.update(convert(kwargs))
ret = dict(
method=webmethod.method or "POST",
url=url,
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
timeout=30,
)
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"
_CLIENT_CLASSES[protocol] = APIClient
return APIClient
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return type_hint
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None

37
src/llama_stack/core/common.sh Executable file
View file

@ -0,0 +1,37 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
# For venv environments, no special cleanup is needed
# This function exists to avoid "function not found" errors
local env_name="$1"
echo "Cleanup called for environment: $env_name"
}
handle_int() {
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
fi
}
# check if a command is present
is_command_available() {
command -v "$1" &>/dev/null
}

View file

@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.core.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.model_dump(),
)
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
provider_type=provider.provider_type,
config={},
),
)
)
logger.info("")
config.providers[api_str] = updated_providers
return config
def upgrade_from_routing_table(
config_dict: dict[str, Any],
) -> dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
# Add default storage config if not present
if "storage" not in config_dict:
config_dict["storage"] = {
"backends": {
"kv_default": {
"type": "kv_sqlite",
"db_path": "~/.llama/kvstore.db",
},
"sql_default": {
"type": "sql_sqlite",
"db_path": "~/.llama/sql_store.db",
},
},
"stores": {
"metadata": {
"namespace": "registry",
"backend": "kv_default",
},
"inference": {
"table_name": "inference_store",
"backend": "sql_default",
"max_write_queue_size": 10000,
"num_writers": 4,
},
"conversations": {
"table_name": "openai_conversations",
"backend": "sql_default",
},
},
}
return config_dict
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import secrets
import time
from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemInclude,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
logger = get_logger(name=__name__, category="openai_conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param run_config: Stack run configuration for resolving persistence
:param policy: Access control rules
"""
run_config: StackRunConfig
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy = config.policy
# Use conversations store reference from run config
conversations_ref = config.run_config.storage.stores.conversations
if not conversations_ref:
raise ValueError("storage.stores.conversations must be configured in run config")
base_sql_store = sqlstore_impl(conversations_ref)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
await self.sql_store.create_table(
"conversation_items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"conversation_id": ColumnType.STRING,
"created_at": ColumnType.INTEGER,
"item_data": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": [],
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
if items:
item_records = []
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
item_records.append(item_record)
await self.sql_store.insert(table="conversation_items", data=item_records)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.debug(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.debug(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
"""Get existing item ID or generate one if missing."""
if item.id is None:
random_bytes = secrets.token_bytes(24)
if item.type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
item_dict["id"] = item_id
return item_id
return item.id
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create (add) items to a conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
base_time = int(time.time())
for i, item in enumerate(items):
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
# make each timestamp unique to maintain order
created_at = base_time + i
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
created_items.append(item_dict)
logger.debug(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
# Get item from conversation_items table
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(record["item_data"])
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items in the conversation."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
# check if conversation exists
await self.get_conversation(conversation_id)
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
records = result.data
if order is not None and order == "asc":
records.sort(key=lambda x: x["created_at"])
else:
records.sort(key=lambda x: x["created_at"], reverse=True)
actual_limit = limit or 20
records = records[:actual_limit]
items = [record["item_data"] for record in records]
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
first_id = response_items[0].id if response_items else None
last_id = response_items[-1].id if response_items else None
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id)
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.delete(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
logger.debug(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@ -0,0 +1,629 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, Self
from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
KVStoreReference,
StorageBackendType,
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
super().__init__(principal=principal, attributes=attributes)
class ResourceWithOwner(Resource):
"""Extension of Resource that adds an optional owner, i.e. the user that created the
resource. This can be used to constrain access to the resource."""
owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects
class ModelWithOwner(Model, ResourceWithOwner):
pass
class ShieldWithOwner(Shield, ResourceWithOwner):
pass
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
pass
class DatasetWithOwner(Dataset, ResourceWithOwner):
pass
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithOwner
| ShieldWithOwner
| VectorStoreWithOwner
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
| ToolGroupWithOwner,
Field(discriminator="type"),
]
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
container_image: str | None = None
routing_table_api: Api
module: str
provider_data_validator: str | None = Field(
default=None,
)
# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
container_image: str | None = None
router_api: Api
module: str
pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class BuildProvider(BaseModel):
provider_type: str
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel):
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: str | None = None
providers: dict[str, list[BuildProvider]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
)
class TelemetryConfig(BaseModel):
"""
Configuration for telemetry.
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
for env variables to configure the OpenTelemetry SDK.
Example:
```bash
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
```
"""
enabled: bool = Field(default=False, description="enable or disable telemetry")
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token"
CUSTOM = "custom"
KUBERNETES = "kubernetes"
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
audience: str = Field(default="llama-stack")
verify_tls: bool = Field(default=True)
tls_cafile: Path | None = Field(default=None)
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
introspection: OAuth2IntrospectionConfig | None = Field(
default=None, description="OAuth2 introspection configuration"
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
endpoint: str = Field(
...,
description="Custom authentication endpoint URL",
)
class GitHubTokenAuthConfig(BaseModel):
"""Configuration for GitHub token authentication."""
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
github_api_base_url: str = Field(
default="https://api.github.com",
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
)
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"login": "roles",
"organizations": "teams",
},
description="Mapping from GitHub user fields to access attributes",
)
class KubernetesAuthProviderConfig(BaseModel):
"""Configuration for Kubernetes authentication provider."""
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
api_server_url: str = Field(
default="https://kubernetes.default.svc",
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
)
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"username": "roles",
"groups": "roles",
},
description="Mapping of Kubernetes user claims to access attributes",
)
@field_validator("api_server_url")
@classmethod
def validate_api_server_url(cls, v):
parsed = urlparse(v)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
if parsed.scheme not in ["http", "https"]:
raise ValueError(f"api_server_url scheme must be http or https: {v}")
return v
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
Field(discriminator="type"),
]
class AuthenticationConfig(BaseModel):
"""Top-level authentication configuration."""
provider_config: AuthProviderConfig = Field(
...,
description="Authentication provider configuration",
)
access_policy: list[AccessRule] = Field(
default=[],
description="Rules for determining access to resources",
)
class AuthenticationRequiredError(Exception):
pass
class QualifiedModel(BaseModel):
"""A qualified model identifier, consisting of a provider ID and a model ID."""
provider_id: str
model_id: str
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
default_provider_id: str | None = Field(
default=None,
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
)
default_embedding_model: QualifiedModel | None = Field(
default=None,
description="Default embedding model configuration for vector stores.",
)
class SafetyConfig(BaseModel):
"""Configuration for default moderations model."""
default_shield_id: str | None = Field(
default=None,
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
)
class QuotaPeriod(StrEnum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class RegisteredResources(BaseModel):
"""Registry of resources available in the distribution."""
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
description="Port to listen on",
ge=1024,
le=65535,
)
tls_certfile: str | None = Field(
default=None,
description="Path to TLS certificate file for HTTPS",
)
tls_keyfile: str | None = Field(
default=None,
description="Path to TLS key file for HTTPS",
)
tls_cafile: str | None = Field(
default=None,
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
)
auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
host: str | None = Field(
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
image_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
container_image: str | None = Field(
default=None,
description="Reference to the container image if this package refers to a container",
)
apis: list[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
providers: dict[str, list[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
storage: StorageConfig = Field(
description="Catalog of named storage backends and references available to the stack",
)
registered_resources: RegisteredResources = Field(
default_factory=RegisteredResources,
description="Registry of resources available in the distribution",
)
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
vector_stores: VectorStoresConfig | None = Field(
default=None,
description="Configuration for vector stores, including default embedding model",
)
safety: SafetyConfig | None = Field(
default=None,
description="Configuration for default moderations model",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
@model_validator(mode="after")
def validate_server_stores(self) -> "StackRunConfig":
backend_map = self.storage.backends
stores = self.storage.stores
kv_backends = {
name
for name, cfg in backend_map.items()
if cfg.type
in {
StorageBackendType.KV_REDIS,
StorageBackendType.KV_SQLITE,
StorageBackendType.KV_POSTGRES,
StorageBackendType.KV_MONGODB,
}
}
sql_backends = {
name
for name, cfg in backend_map.items()
if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
}
def _ensure_backend(reference, expected_set, store_name: str) -> None:
if reference is None:
return
backend_name = reference.backend
if backend_name not in backend_map:
raise ValueError(
f"{store_name} references unknown backend '{backend_name}'. "
f"Available backends: {sorted(backend_map)}"
)
if backend_name not in expected_set:
raise ValueError(
f"{store_name} references backend '{backend_name}' of type "
f"'{backend_map[backend_name].type.value}', but a backend of type "
f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
)
_ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
_ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
_ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
_ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
_ensure_backend(stores.prompts, kv_backends, "storage.stores.prompts")
return self
class BuildConfig(BaseModel):
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="venv",
description="Type of package to build (container | venv)",
)
image_name: str | None = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v

View file

@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
import os
from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
RemoteProviderSpec,
)
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]:
return list(Api)
class AutoRoutedApiInfo(BaseModel):
routing_table_api: Api
router_api: Api
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.benchmarks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_stores,
router_api=Api.vector_io,
),
]
def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
return spec
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like:
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args:
config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
Returns:
A dictionary mapping APIs to their available providers
Raises:
FileNotFoundError: If the external providers directory doesn't exist
ValueError: If any provider spec is invalid
"""
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
# Check if config has external providers
if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
return registry
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
if isinstance(spec, list):
# optionally allow people to pass inline and remote provider specs as a returned list.
# with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListRoutesResponse,
RouteInfo,
VersionInfo,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self, config: DistributionInspectConfig, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_routes(self) -> ListRoutesResponse:
run_config: StackRunConfig = self.config.run_config
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, _ in endpoints
if e.methods is not None
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, _ in endpoints
if e.methods is not None
]
)
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status=HealthStatus.OK)
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,540 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import logging # allow-direct-logging
import os
import sys
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
Stack,
get_stack_run_config_from_distro,
replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")
T = TypeVar("T")
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
logger.error(f"Error converting list {value} into {item_type}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
logger.error(f"Error converting dict {value} into {val_type}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
# TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
# We should get rid of any non-discriminated unions in the API schema.
if origin is Union:
for union_type in get_args(annotation):
try:
return convert_to_pydantic(union_type, value)
except Exception:
continue
logger.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LibraryClientUploadFile:
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
def __init__(self, filename: str, content: bytes):
self.filename = filename
self.content = content
self.content_type = "application/octet-stream"
async def read(self) -> bytes:
return self.content
class LibraryClientHttpxResponse:
"""LibraryClient httpx Response object for FastAPI Response conversion."""
def __init__(self, response):
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
self.status_code = response.status_code
self.headers = response.headers
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_distro_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def initialize(self):
"""
Deprecated method for backward compatibility.
"""
pass
def request(self, *args, **kwargs):
loop = self.loop
asyncio.set_event_loop(loop)
if kwargs.get("stream"):
def sync_generator():
try:
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return sync_generator()
else:
try:
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return result
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# Initialize logging from environment variables first
setup_logging()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
current_sinks = sinks_from_env.strip().lower().split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# distribution
config = get_stack_run_config_from_distro(config_path_or_distro_name)
self.config_path_or_distro_name = config_path_or_distro_name
self.config = config
self.custom_provider_registry = custom_provider_registry
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
stack = Stack(self.config, self.custom_provider_registry)
await stack.initialize()
self.impls = stack.impls
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
"Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n",
color="yellow",
file=sys.stderr,
)
if self.config_path_or_distro_name.endswith(".yaml"):
providers: dict[str, list[BuildProvider]] = {}
for api, run_providers in self.config.providers.items():
for provider in run_providers:
providers.setdefault(api, []).append(
BuildProvider(provider_type=provider.provider_type, module=provider.module)
)
providers = dict(providers)
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=providers,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
"yellow",
file=sys.stderr,
)
cprint(
"Please check your internet connection and try again.",
"red",
file=sys.stderr,
)
raise _e
assert self.impls is not None
if self.config.telemetry.enabled:
setup_logger(Telemetry())
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:")
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
self.route_impls = initialize_route_impls(self.impls)
return True
async def request(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
if self.route_impls is None:
raise ValueError("Client not initialized. Please call initialize() first.")
# Create headers with provider data if available
headers = options.headers or {}
if self.provider_data:
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
"""Handle file uploads from OpenAI client and add them to the request body."""
if not (hasattr(options, "files") and options.files):
return body, []
if not isinstance(options.files, list):
return body, []
field_names = []
for file_tuple in options.files:
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
continue
field_name = file_tuple[0]
file_object = file_tuple[1]
if isinstance(file_object, BytesIO):
file_object.seek(0)
file_content = file_object.read()
filename = getattr(file_object, "name", "uploaded_file")
field_names.append(field_name)
body[field_name] = LibraryClientUploadFile(filename, file_content)
return body, field_names
async def _call_non_streaming(
self,
*,
cast_to: Any,
options: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
if hasattr(options, "extra_json") and options.extra_json:
body |= options.extra_json
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
await end_trace()
# Handle FastAPI Response objects (e.g., from file content retrieval)
if isinstance(result, FastAPIResponse):
return LibraryClientHttpxResponse(result)
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
status_code = httpx.codes.OK
if options.method.upper() == "DELETE" and result is None:
status_code = httpx.codes.NO_CONTENT
if status_code == httpx.codes.NO_CONTENT:
json_content = ""
mock_response = httpx.Response(
status_code=status_code,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(filtered_body),
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._convert_body(func, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen():
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(body),
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
body = body or {}
exclude_params = exclude_params or set()
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if is_unwrapped_body_param(param_type):
base_type = get_args(param_type)[0]
return {param.name: base_type(**body)}
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Check if there's an unwrapped body parameter among multiple parameters
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
unwrapped_body_param = None
for param in params_list:
if is_unwrapped_body_param(param.annotation):
unwrapped_body_param = param
break
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
if param_name in exclude_params:
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
# handle unwrapped body parameter after processing all named parameters
if unwrapped_body_param:
base_type = get_args(unwrapped_body_param.annotation)[0]
# extract only keys not already used by other params
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)
return converted_body

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
class PromptServiceConfig(BaseModel):
"""Configuration for the built-in prompt service.
:param run_config: Stack run configuration containing distribution info
"""
run_config: StackRunConfig
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
"""Get the prompt service implementation."""
impl = PromptServiceImpl(config, deps)
await impl.initialize()
return impl
class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
# Use prompts store reference from run config
prompts_ref = self.config.run_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
if resolved_version is None:
raise ValueError(f"Prompt {prompt_id}:default not found")
return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
return f"prompts:v1:{prompt_id}:{version}"
def _get_list_key_prefix(self) -> str:
"""Get the key prefix for listing prompts."""
return "prompts:v1:"
def _serialize_prompt(self, prompt: Prompt) -> str:
"""Serialize a prompt to JSON string for storage."""
return json.dumps(
{
"prompt_id": prompt.prompt_id,
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
"is_default": prompt.is_default,
}
)
def _deserialize_prompt(self, data: str) -> Prompt:
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only)."""
prefix = self._get_list_key_prefix()
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
prompts = []
for key in keys:
if key.endswith(":default"):
try:
default_version = await self.kvstore.get(key)
if default_version:
prompt_id = key.replace(prefix, "").replace(":default", "")
version_key = self._get_version_key(prompt_id, default_version)
data = await self.kvstore.get(version_key)
if data:
prompt = self._deserialize_prompt(data)
prompts.append(prompt)
except (json.JSONDecodeError, KeyError):
continue
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
if data is None:
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt."""
if variables is None:
variables = []
prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None:
variables = []
prompt_versions = await self.list_prompt_versions(prompt_id)
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
if version and latest_prompt.version != version:
raise ValueError(
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
)
current_version = latest_prompt.version if version is None else version
new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
if set_as_default:
await self.set_default_version(prompt_id, new_version)
return updated_prompt
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete a prompt and all its versions."""
await self.get_prompt(prompt_id)
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
for key in keys:
await self.kvstore.delete(key)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt."""
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
default_version = None
prompts = []
for key in keys:
data = await self.kvstore.get(key)
if key.endswith(":default"):
default_version = data
else:
if data:
prompt_obj = self._deserialize_prompt(data)
prompts.append(prompt_obj)
if not prompts:
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)

View file

@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
for api, providers in safe_config.providers.items():
for p in providers:
# Skip providers that are not enabled
if p.provider_id is None:
continue
ret.append(
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
),
)
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: dict[str, dict[str, HealthResponse]] = {}
# The timeout has to be long enough to allow all the providers to be checked, especially in
# the case of the inference router health check since it checks all registered inference
# providers.
# The timeout must not be equal to the one set by health method for a given implementation,
# otherwise we will miss some providers.
timeout = 3.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
# Skip special implementations (inspect/providers) that don't have provider specs
if not hasattr(impl, "__provider_spec__"):
return None
api_name = impl.__provider_spec__.api.name
if not hasattr(impl, "health"):
return (
api_name,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except TimeoutError:
return (
api_name,
HealthResponse(
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
),
)
except Exception as e:
return (
api_name,
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
)
# Create tasks for all providers
tasks = [check_provider_health(impl) for impl in self.deps.values()]
# Wait for all health checks to complete
results = await asyncio.gather(*tasks)
# Organize results by API and provider ID
for result in results:
if result is None: # Skip special implementations
continue
api_name, health_response = result
providers_health[api_name] = health_response
return providers_health

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = get_logger(name=__name__, category="core")
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self.provider_data = provider_data or {}
if user:
self.provider_data["__authenticated_user"] = user
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = PROVIDER_DATA_VAR.get()
if not val:
return None
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return None
try:
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!")
return None
def request_provider_data_context(
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

Some files were not shown because too many files have changed in this diff Show more