Merge origin/main into add-missing-provider-data-impls

Resolved conflicts in:
- benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema)
- llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support)
- llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support)
- llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support)
- llama_stack/providers/remote/inference/runpod/config.py (merged imports)
- pyproject.toml (kept databricks-sdk dependency)
Author:  Ashwin Bharambe
Date:    2025-10-27 11:39:00 -07:00
Commit:  9eb9a37ee4

1,880 changed files with 804,868 additions and 70,533 deletions

@ -43,17 +43,17 @@ from .openai_responses import (
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation.
:param type: The type/identifier of the shield.
:param type: The type/identifier of the guardrail.
"""
type: str
# TODO: more fields to be added for shield configuration
# TODO: more fields to be added for guardrail configuration
ResponseShield = str | ResponseShieldSpec
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
@ -812,6 +812,7 @@ class Agents(Protocol):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -819,10 +820,10 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[
list[ResponseShield] | None,
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of shields to apply during response generation. Shields provide safety and content moderation."
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
@ -831,8 +832,9 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
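
Because `guardrails` is declared via `ExtraBodyField`, it rides in the request's extra body rather than in the OpenAI-standard parameter list. A hedged sketch of what a call might look like through an OpenAI-compatible client pointed at a Llama Stack server (the client object, base URL, and model id are illustrative assumptions, not part of this diff):

from openai import OpenAI  # assumes the standard OpenAI Python client

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # illustrative endpoint

response = client.responses.create(
    model="llama3.2:3b",                     # illustrative model id
    input="Summarize the incident report.",
    extra_body={
        # arrives server-side as the `guardrails` ExtraBodyField; entries may be
        # guardrail IDs (strings) or ResponseGuardrailSpec-shaped objects
        "guardrails": ["llama-guard", {"type": "llama-guard"}],
    },
)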

@ -131,8 +131,20 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseOutputMessageContent = Annotated[
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@ -346,6 +358,174 @@ class OpenAIResponseText(BaseModel):
format: OpenAIResponseTextFormat | None = None
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
"""Web search tool configuration for OpenAI response inputs.
:param type: Web search tool type variant to use
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
"""
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
"""Function tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "function"
:param name: Name of the function that can be called
:param description: (Optional) Description of what the function does
:param parameters: (Optional) JSON schema defining the function's parameters
:param strict: (Optional) Whether to enforce strict parameter validation
"""
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
"""File search tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "file_search"
:param vector_store_ids: List of vector store identifiers to search within
:param filters: (Optional) Additional filters to apply to the search
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
:param ranking_options: (Optional) Options for ranking and scoring search results
"""
type: Literal["file_search"] = "file_search"
vector_store_ids: list[str]
filters: dict[str, Any] | None = None
max_num_results: int | None = Field(default=10, ge=1, le=50)
ranking_options: FileSearchRankingOptions | None = None
class ApprovalFilter(BaseModel):
"""Filter configuration for MCP tool approval requirements.
:param always: (Optional) List of tool names that always require approval
:param never: (Optional) List of tool names that never require approval
"""
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
"""Filter configuration for restricting which MCP tools can be used.
:param tool_names: (Optional) List of specific tool names that are allowed
"""
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param server_url: URL endpoint of the MCP server
:param headers: (Optional) HTTP headers to include when connecting to the server
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
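
Input tools form a `type`-discriminated union, so a single request can mix the variants defined above. A hedged sketch of assembling such a list (the server URL, labels, and tool names are illustrative assumptions):

tools = [
    OpenAIResponseInputToolWebSearch(search_context_size="high"),
    OpenAIResponseInputToolFileSearch(vector_store_ids=["vs_abc123"], max_num_results=5),
    OpenAIResponseInputToolMCP(
        server_label="docs",
        server_url="http://localhost:3000/mcp",   # illustrative MCP endpoint
        require_approval=ApprovalFilter(always=["delete_file"]),
        allowed_tools=["search_docs", "delete_file"],
    ),
]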
@json_schema_type
class OpenAIResponseToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response object.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseToolMCP,  # The only type that differs from its input counterpart is the MCP tool
Field(discriminator="type"),
]
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI response usage.
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
"""
reasoning_tokens: int | None = None
class OpenAIResponseUsageInputTokensDetails(BaseModel):
"""Token details for input tokens in OpenAI response usage.
:param cached_tokens: Number of tokens retrieved from cache
"""
cached_tokens: int | None = None
@json_schema_type
class OpenAIResponseUsage(BaseModel):
"""Usage information for OpenAI response.
:param input_tokens: Number of tokens in the input
:param output_tokens: Number of tokens in the output
:param total_tokens: Total tokens used (input + output)
:param input_tokens_details: Detailed breakdown of input token usage
:param output_tokens_details: Detailed breakdown of output token usage
"""
input_tokens: int
output_tokens: int
total_tokens: int
input_tokens_details: OpenAIResponseUsageInputTokensDetails | None = None
output_tokens_details: OpenAIResponseUsageOutputTokensDetails | None = None
@json_schema_type
class OpenAIResponseObject(BaseModel):
"""Complete OpenAI response object containing generation results and metadata.
@ -362,7 +542,10 @@ class OpenAIResponseObject(BaseModel):
:param temperature: (Optional) Sampling temperature used for generation
:param text: Text formatting configuration for the response
:param top_p: (Optional) Nucleus sampling parameter used for generation
:param tools: (Optional) An array of tools the model may call while generating a response.
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
"""
created_at: int
@ -379,7 +562,10 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
@json_schema_type
@ -400,7 +586,7 @@ class OpenAIDeleteResponseObject(BaseModel):
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
"""Streaming event indicating a new response has been created.
:param response: The newly created response object
:param response: The response object that was created
:param type: Event type identifier, always "response.created"
"""
@ -408,11 +594,25 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseInProgress(BaseModel):
"""Streaming event indicating the response remains in progress.
:param response: Current response state while in progress
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.in_progress"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.in_progress"] = "response.in_progress"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
"""Streaming event indicating a response has been completed.
:param response: The completed response object
:param response: Completed response object
:param type: Event type identifier, always "response.completed"
"""
@ -420,6 +620,34 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
type: Literal["response.completed"] = "response.completed"
@json_schema_type
class OpenAIResponseObjectStreamResponseIncomplete(BaseModel):
"""Streaming event emitted when a response ends in an incomplete state.
:param response: Response object describing the incomplete state
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.incomplete"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.incomplete"] = "response.incomplete"
@json_schema_type
class OpenAIResponseObjectStreamResponseFailed(BaseModel):
"""Streaming event emitted when a response fails.
:param response: Response object describing the failure
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.failed"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.failed"] = "response.failed"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
"""Streaming event for when a new output item is added to the response.
@ -650,19 +878,34 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
@json_schema_type
class OpenAIResponseContentPartOutputText(BaseModel):
"""Text content within a streamed response part.
:param type: Content part type identifier, always "output_text"
:param text: Text emitted for this content part
:param annotations: Structured annotations associated with the text
:param logprobs: (Optional) Token log probability details
"""
type: Literal["output_text"] = "output_text"
text: str
# TODO: add annotations, logprobs, etc.
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
logprobs: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str
class OpenAIResponseContentPartReasoningText(BaseModel):
"""Reasoning text emitted as part of a streamed response.
:param type: Content part type identifier, always "reasoning_text"
:param text: Reasoning text supplied by the model
"""
type: Literal["reasoning_text"] = "reasoning_text"
text: str
OpenAIResponseContentPart = Annotated[
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
Field(discriminator="type"),
]
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
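
The content-part union relies on Pydantic discriminated unions: the literal `type` field selects which model validates a payload. A self-contained sketch of the mechanism using simplified stand-in models (not the real classes above):

from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter

class OutputText(BaseModel):
    type: Literal["output_text"] = "output_text"
    text: str

class Refusal(BaseModel):
    type: Literal["refusal"] = "refusal"
    refusal: str

class ReasoningText(BaseModel):
    type: Literal["reasoning_text"] = "reasoning_text"
    text: str

ContentPart = Annotated[OutputText | Refusal | ReasoningText, Field(discriminator="type")]

# the "type" value routes the payload to the matching model
part = TypeAdapter(ContentPart).validate_python({"type": "refusal", "refusal": "I can't help with that."})
print(type(part).__name__)  # Refusal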
@ -672,15 +915,19 @@ register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
"""Streaming event for when a new content part is added to a response item.
:param content_index: Index position of the part within the content array
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param output_index: Index position of the output item in the response
:param part: The content part that was added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.added"
"""
content_index: int
response_id: str
item_id: str
output_index: int
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.added"] = "response.content_part.added"
@ -690,22 +937,269 @@ class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
"""Streaming event for when a content part is completed.
:param content_index: Index position of the part within the content array
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param output_index: Index position of the output item in the response
:param part: The completed content part
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.done"
"""
content_index: int
response_id: str
item_id: str
output_index: int
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.done"] = "response.content_part.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningTextDelta(BaseModel):
"""Streaming event for incremental reasoning text updates.
:param content_index: Index position of the reasoning content part
:param delta: Incremental reasoning text being added
:param item_id: Unique identifier of the output item being updated
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.reasoning_text.delta"
"""
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningTextDone(BaseModel):
"""Streaming event for when reasoning text is completed.
:param content_index: Index position of the reasoning content part
:param text: Final complete reasoning text
:param item_id: Unique identifier of the completed output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.reasoning_text.done"
"""
content_index: int
text: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
@json_schema_type
class OpenAIResponseContentPartReasoningSummary(BaseModel):
"""Reasoning summary part in a streamed response.
:param type: Content part type identifier, always "summary_text"
:param text: Summary text
"""
type: Literal["summary_text"] = "summary_text"
text: str
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
"""Streaming event for when a new reasoning summary part is added.
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param part: The summary part that was added
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_part.added"
"""
item_id: str
output_index: int
part: OpenAIResponseContentPartReasoningSummary
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
"""Streaming event for when a reasoning summary part is completed.
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param part: The completed summary part
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_part.done"
"""
item_id: str
output_index: int
part: OpenAIResponseContentPartReasoningSummary
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
"""Streaming event for incremental reasoning summary text updates.
:param delta: Incremental summary text being added
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_text.delta"
"""
delta: str
item_id: str
output_index: int
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
"""Streaming event for when reasoning summary text is completed.
:param text: Final complete summary text
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_text.done"
"""
text: str
item_id: str
output_index: int
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseRefusalDelta(BaseModel):
"""Streaming event for incremental refusal text updates.
:param content_index: Index position of the content part
:param delta: Incremental refusal text being added
:param item_id: Unique identifier of the output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.refusal.delta"
"""
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.refusal.delta"] = "response.refusal.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseRefusalDone(BaseModel):
"""Streaming event for when refusal text is completed.
:param content_index: Index position of the content part
:param refusal: Final complete refusal text
:param item_id: Unique identifier of the output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.refusal.done"
"""
content_index: int
refusal: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.refusal.done"] = "response.refusal.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
"""Streaming event for when an annotation is added to output text.
:param item_id: Unique identifier of the item to which the annotation is being added
:param output_index: Index position of the output item in the response's output array
:param content_index: Index position of the content part within the output item
:param annotation_index: Index of the annotation within the content part
:param annotation: The annotation object being added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.output_text.annotation.added"
"""
item_id: str
output_index: int
content_index: int
annotation_index: int
annotation: OpenAIResponseAnnotations
sequence_number: int
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
"""Streaming event for file search calls in progress.
:param item_id: Unique identifier of the file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.in_progress"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
"""Streaming event for file search currently searching.
:param item_id: Unique identifier of the file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.searching"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
"""Streaming event for completed file search calls.
:param item_id: Unique identifier of the completed file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.completed"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseInProgress
| OpenAIResponseObjectStreamResponseOutputItemAdded
| OpenAIResponseObjectStreamResponseOutputItemDone
| OpenAIResponseObjectStreamResponseOutputTextDelta
@ -725,6 +1219,20 @@ OpenAIResponseObjectStream = Annotated[
| OpenAIResponseObjectStreamResponseMcpCallCompleted
| OpenAIResponseObjectStreamResponseContentPartAdded
| OpenAIResponseObjectStreamResponseContentPartDone
| OpenAIResponseObjectStreamResponseReasoningTextDelta
| OpenAIResponseObjectStreamResponseReasoningTextDone
| OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded
| OpenAIResponseObjectStreamResponseReasoningSummaryPartDone
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDone
| OpenAIResponseObjectStreamResponseRefusalDelta
| OpenAIResponseObjectStreamResponseRefusalDone
| OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded
| OpenAIResponseObjectStreamResponseFileSearchCallInProgress
| OpenAIResponseObjectStreamResponseFileSearchCallSearching
| OpenAIResponseObjectStreamResponseFileSearchCallCompleted
| OpenAIResponseObjectStreamResponseIncomplete
| OpenAIResponseObjectStreamResponseFailed
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
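
A hedged sketch of consuming the widened stream union, dispatching on the `type` discriminator. `stream` is assumed to be the AsyncIterator[OpenAIResponseObjectStream] returned by create_openai_response(..., stream=True), and only fields defined above are touched:

async def watch(stream) -> None:
    async for event in stream:
        if event.type == "response.reasoning_text.delta":
            print(event.delta, end="", flush=True)            # incremental reasoning text
        elif event.type == "response.refusal.done":
            print(f"\n[refused] {event.refusal}")             # final refusal text
        elif event.type in ("response.completed", "response.incomplete", "response.failed"):
            print(f"\n[{event.type}] usage={event.response.usage}")
            break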
@ -746,128 +1254,15 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
OpenAIResponseInput = Annotated[
# Responses API allows output messages to be passed in as input
OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
OpenAIResponseOutput
| OpenAIResponseInputFunctionToolCallOutput
| OpenAIResponseMCPApprovalRequest
| OpenAIResponseMCPApprovalResponse
|
# Fallback to the generic message type as a last resort
OpenAIResponseMessage,
| OpenAIResponseMessage,
Field(union_mode="left_to_right"),
]
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
"""Web search tool configuration for OpenAI response inputs.
:param type: Web search tool type variant to use
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
"""
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
"""Function tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "function"
:param name: Name of the function that can be called
:param description: (Optional) Description of what the function does
:param parameters: (Optional) JSON schema defining the function's parameters
:param strict: (Optional) Whether to enforce strict parameter validation
"""
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
"""File search tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "file_search"
:param vector_store_ids: List of vector store identifiers to search within
:param filters: (Optional) Additional filters to apply to the search
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
:param ranking_options: (Optional) Options for ranking and scoring search results
"""
type: Literal["file_search"] = "file_search"
vector_store_ids: list[str]
filters: dict[str, Any] | None = None
max_num_results: int | None = Field(default=10, ge=1, le=50)
ranking_options: FileSearchRankingOptions | None = None
class ApprovalFilter(BaseModel):
"""Filter configuration for MCP tool approval requirements.
:param always: (Optional) List of tool names that always require approval
:param never: (Optional) List of tool names that never require approval
"""
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
"""Filter configuration for restricting which MCP tools can be used.
:param tool_names: (Optional) List of specific tool names that are allowed
"""
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param server_url: URL endpoint of the MCP server
:param headers: (Optional) HTTP headers to include when connecting to the server
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
class ListOpenAIResponseInputItem(BaseModel):
"""List container for OpenAI response input items.

@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
def __init__(self, message: str) -> None:
super().__init__(message)
class ConversationNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced conversation"""
def __init__(self, conversation_id: str) -> None:
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
class InvalidConversationIdError(ValueError):
"""raised when a conversation ID has an invalid format"""
def __init__(self, conversation_id: str) -> None:
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
super().__init__(message)
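
A hypothetical guard (not part of this diff) showing how the two errors divide the work: malformed IDs fail fast with InvalidConversationIdError, while lookups that come back empty raise ConversationNotFoundError. The store object and its get() method are assumptions:

async def _load_conversation(store, conversation_id: str):
    if not conversation_id.startswith("conv_"):
        raise InvalidConversationIdError(conversation_id)
    conversation = await store.get(conversation_id)   # hypothetical lookup
    if conversation is None:
        raise ConversationNotFoundError(conversation_id)
    return conversation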

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
@ -20,7 +21,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@ -61,9 +62,14 @@ class ConversationMessage(BaseModel):
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
| OpenAIResponseMCPApprovalRequest
| OpenAIResponseMCPApprovalResponse
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
@ -142,6 +148,20 @@ class ConversationItemCreateRequest(BaseModel):
)
class ConversationItemInclude(StrEnum):
"""
Specify additional output data to include in the model response.
"""
web_search_call_action_sources = "web_search_call.action.sources"
code_interpreter_call_outputs = "code_interpreter_call.outputs"
computer_call_output_output_image_url = "computer_call_output.output.image_url"
file_search_call_results = "file_search_call.results"
message_input_image_image_url = "message.input_image.image_url"
message_output_text_logprobs = "message.output_text.logprobs"
reasoning_encrypted_content = "reasoning.encrypted_content"
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
@ -165,7 +185,9 @@ class ConversationItemDeletedResource(BaseModel):
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
"""Conversations
Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
@ -173,6 +195,8 @@ class Conversations(Protocol):
) -> Conversation:
"""Create a conversation.
Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
@ -181,7 +205,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
"""Retrieve a conversation.
Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
@ -190,7 +216,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
"""Update a conversation.
Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
@ -200,7 +228,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
"""Delete a conversation.
Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
@ -209,7 +239,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
"""Create items.
Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
@ -219,7 +251,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
"""Retrieve an item.
Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
@ -228,15 +262,17 @@ class Conversations(Protocol):
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list(
async def list_items(
self,
conversation_id: str,
after: str | NotGiven = NOT_GIVEN,
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items in the conversation.
"""List items.
List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
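
With the OpenAI NotGiven sentinels gone, list_items takes ordinary keyword defaults and the new ConversationItemInclude enum. A hedged sketch of a call against the new signature (the conversations object is any implementation of the protocol; the fields of the returned ConversationItemList are defined elsewhere in the file):

async def show_recent_items(conversations) -> None:
    items = await conversations.list_items(
        "conv_123",
        limit=20,
        order="desc",
        include=[ConversationItemInclude.file_search_call_results],
    )
    print(items)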
@ -251,7 +287,9 @@ class Conversations(Protocol):
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
"""Delete an item.
Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.

@ -96,7 +96,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar vector_dbs: Vector database management
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
@ -118,11 +117,9 @@ class Api(Enum, metaclass=DynamicApiMeta):
post_training = "post_training"
tool_runtime = "tool_runtime"
telemetry = "telemetry"
models = "models"
shields = "shields"
vector_dbs = "vector_dbs"
vector_stores = "vector_stores" # only used for routing table
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"

@ -82,7 +82,9 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
"""Evaluations
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod

@ -14,6 +14,7 @@ from typing import (
runtime_checkable,
)
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
@ -22,6 +23,7 @@ from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry import MetricResponseMixin
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@ -29,7 +31,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
@ -96,7 +97,7 @@ class SamplingParams(BaseModel):
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: int | None = 0
max_tokens: int | None = None
repetition_penalty: float | None = 1.0
stop: list[str] | None = None
@ -776,12 +777,14 @@ class OpenAIChoiceDelta(BaseModel):
:param refusal: (Optional) The refusal of the delta
:param role: (Optional) The role of the delta
:param tool_calls: (Optional) The tool calls of the delta
:param reasoning_content: (Optional) The reasoning content from the model (non-standard, for o1/o3 models)
"""
content: str | None = None
refusal: str | None = None
role: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
reasoning_content: str | None = None
@json_schema_type
@ -816,6 +819,42 @@ class OpenAIChoice(BaseModel):
logprobs: OpenAIChoiceLogprobs | None = None
class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI chat completion usage.
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
"""
reasoning_tokens: int | None = None
class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
"""Token details for prompt tokens in OpenAI chat completion usage.
:param cached_tokens: Number of tokens retrieved from cache
"""
cached_tokens: int | None = None
@json_schema_type
class OpenAIChatCompletionUsage(BaseModel):
"""Usage information for OpenAI chat completion.
:param prompt_tokens: Number of tokens in the prompt
:param completion_tokens: Number of tokens in the completion
:param total_tokens: Total tokens used (prompt + completion)
:param prompt_tokens_details: Detailed breakdown of prompt token usage
:param completion_tokens_details: Detailed breakdown of completion token usage
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None
@json_schema_type
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request.
@ -825,6 +864,7 @@ class OpenAIChatCompletion(BaseModel):
:param object: The object type, which will be "chat.completion"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
:param usage: Token usage information for the completion
"""
id: str
@ -832,6 +872,7 @@ class OpenAIChatCompletion(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
@ -843,6 +884,7 @@ class OpenAIChatCompletionChunk(BaseModel):
:param object: The object type, which will be "chat.completion.chunk"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
:param usage: Token usage information (typically included in final chunk with stream_options)
"""
id: str
@ -850,6 +892,7 @@ class OpenAIChatCompletionChunk(BaseModel):
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
@ -995,6 +1038,127 @@ class ListOpenAIChatCompletionResponse(BaseModel):
object: Literal["list"] = "list"
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
"""
# Standard OpenAI completion parameters
model: str
prompt: str | list[str] | list[int] | list[list[int]]
best_of: int | None = None
echo: bool | None = None
frequency_penalty: float | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
suffix: str | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible chat completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
"""
# Standard OpenAI chat completion parameters
model: str
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
frequency_penalty: float | None = None
function_call: str | dict[str, Any] | None = None
functions: list[dict[str, Any]] | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_completion_tokens: int | None = None
max_tokens: int | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
presence_penalty: float | None = None
response_format: OpenAIResponseFormatParam | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
top_p: float | None = None
user: str | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible embeddings endpoint.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
"""
model: str
input: str | list[str]
encoding_format: str | None = "float"
dimensions: int | None = None
user: str | None = None
@runtime_checkable
@trace_protocol
class InferenceProvider(Protocol):
@ -1029,52 +1193,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
"""Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
:returns: An OpenAICompletion.
"""
...
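
The refactor collapses the long parameter list into a single Pydantic object declared with extra="allow", so provider-specific knobs (for example the vLLM guided_choice and prompt_logprobs parameters dropped from the signature above) now travel in the extra body and are read back via .model_extra. A self-contained sketch of the mechanism with a simplified stand-in model (not the real class):

from pydantic import BaseModel

class CompletionRequest(BaseModel, extra="allow"):
    # simplified stand-in for OpenAICompletionRequestWithExtraBody
    model: str
    prompt: str
    max_tokens: int | None = None

req = CompletionRequest(
    model="llama3.2:3b",                  # illustrative model id
    prompt="The capital of France is",
    guided_choice=["Paris", "Lyon"],      # provider-specific field, kept in the extra body
)
print(req.model_extra)                    # {'guided_choice': ['Paris', 'Lyon']}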
@ -1083,57 +1206,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
"""
...
@ -1142,21 +1219,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
"""Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified model.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
"""
...
@ -1167,9 +1234,10 @@ class Inference(InferenceProvider):
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models reorder the documents based on their relevance to a query.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)

@ -73,7 +73,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def health(self) -> HealthInfo:
"""Get health status.
@ -83,7 +83,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def version(self) -> VersionInfo:
"""Get version.

View file

@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@@ -27,10 +27,12 @@ class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
"""
llm = "llm"
embedding = "embedding"
rerank = "rerank"
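# Illustrative sketch (not part of this diff): registering a model with the new rerank
# type, inside an async context. The register_model signature, provider id and model ids
# used here are assumptions, not shown in this change.
reranker = await models.register_model(
    model_id="my-reranker",
    provider_id="remote::nvidia",                 # placeholder provider
    provider_model_id="nv-rerank-qa-mistral-4b",  # placeholder provider model id
    model_type=ModelType.rerank,
)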
@json_schema_type

View file

@@ -11,7 +11,7 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod

View file

@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
class ResourceType(StrEnum):
model = "model"
shield = "shield"
vector_db = "vector_db"
vector_store = "vector_store"
dataset = "dataset"
scoring_function = "scoring_function"
benchmark = "benchmark"
@@ -34,4 +34,4 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_store', etc.)")

View file

@@ -9,10 +9,10 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@@ -107,7 +107,7 @@ class Safety(Protocol):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run shield.
@@ -123,13 +123,13 @@ class Safety(Protocol):
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use.
:param model: (Optional) The content moderation model you would like to use.
:returns: A moderation object.
"""
...
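# Illustrative sketch (not part of this diff): exercising the updated Safety surface
# inside an async context. run_shield now takes OpenAI-style message params and
# run_moderation's model is optional. The shield id, message shape and response fields
# are assumptions; `safety` stands in for a resolved Safety implementation.
shield_response = await safety.run_shield(
    shield_id="llama-guard",
    messages=[{"role": "user", "content": "How do I bake a cake?"}],  # stands in for OpenAIMessageParam objects
    params={},
)
if shield_response.violation:
    print(shield_response.violation.user_message)

moderation = await safety.run_moderation(input=["some user supplied text"])
print(moderation.results[0].flagged)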

View file

@@ -10,7 +10,7 @@ from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod

View file

@@ -16,15 +16,12 @@ from typing import (
from pydantic import BaseModel, Field
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.schema_utils import json_schema_type, register_schema
# Default telemetry event retention and the scope required to read telemetry data
DEFAULT_TTL_DAYS = 7
REQUIRED_SCOPE = "telemetry.read"
@json_schema_type
class SpanStatus(Enum):
@@ -413,7 +410,6 @@ class QueryMetricsResponse(BaseModel):
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST", level=LLAMA_STACK_API_V1)
async def log_event(
self,
event: Event,
@@ -425,174 +421,3 @@ class Telemetry(Protocol):
:param ttl_seconds: The time to live of the event.
"""
...
@webmethod(
route="/telemetry/traces",
method="POST",
required_scope=REQUIRED_SCOPE,
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse:
"""Query traces.
:param attribute_filters: The attribute filters to apply to the traces.
:param limit: The limit of traces to return.
:param offset: The offset of the traces to return.
:param order_by: The order by of the traces to return.
:returns: A QueryTracesResponse.
"""
...
@webmethod(
route="/telemetry/traces/{trace_id:path}",
method="GET",
required_scope=REQUIRED_SCOPE,
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/telemetry/traces/{trace_id:path}",
method="GET",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
:param trace_id: The ID of the trace to get.
:returns: A Trace.
"""
...
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
method="GET",
required_scope=REQUIRED_SCOPE,
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
method="GET",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
:param trace_id: The ID of the trace to get the span from.
:param span_id: The ID of the span to get.
:returns: A Span.
"""
...
@webmethod(
route="/telemetry/spans/{span_id:path}/tree",
method="POST",
deprecated=True,
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/telemetry/spans/{span_id:path}/tree",
method="POST",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_span_tree(
self,
span_id: str,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse:
"""Get a span tree by its ID.
:param span_id: The ID of the span to get the tree from.
:param attributes_to_return: The attributes to return in the tree.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpanTreeResponse.
"""
...
@webmethod(
route="/telemetry/spans",
method="POST",
required_scope=REQUIRED_SCOPE,
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
async def query_spans(
self,
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse:
"""Query spans.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_return: The attributes to return in the spans.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpansResponse.
"""
...
@webmethod(route="/telemetry/spans/export", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def save_spans_to_dataset(
self,
attribute_filters: list[QueryCondition],
attributes_to_save: list[str],
dataset_id: str,
max_depth: int | None = None,
) -> None:
"""Save spans to a dataset.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_save: The attributes to save to the dataset.
:param dataset_id: The ID of the dataset to save the spans to.
:param max_depth: The maximum depth of the tree.
"""
...
@webmethod(
route="/telemetry/metrics/{metric_name}",
method="POST",
required_scope=REQUIRED_SCOPE,
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/telemetry/metrics/{metric_name}",
method="POST",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1ALPHA,
)
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
"""Query metrics.
:param metric_name: The name of the metric to query.
:param start_time: The start time of the metric to query.
:param end_time: The end time of the metric to query.
:param granularity: The granularity of the metric to query.
:param query_type: The type of query to perform.
:param label_matchers: The label matchers to apply to the metric.
:returns: A QueryMetricsResponse.
"""
...

View file

@@ -12,7 +12,7 @@ from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

View file

@@ -13,7 +13,7 @@ from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime

View file

@@ -1,117 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class VectorDB(Resource):
"""Vector database resource for storing and querying vector embeddings.
:param type: Type of resource, always 'vector_db' for vector databases
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
"""
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str
embedding_dimension: int
vector_db_name: str | None = None
@property
def vector_db_id(self) -> str:
return self.identifier
@property
def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id
class VectorDBInput(BaseModel):
"""Input parameters for creating or configuring a vector database.
:param vector_db_id: Unique identifier for the vector database
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_db_id: (Optional) Provider-specific identifier for the vector database
"""
vector_db_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_db_id: str | None = None
class ListVectorDBsResponse(BaseModel):
"""Response from listing vector databases.
:param data: List of vector databases
"""
data: list[VectorDB]
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@@ -11,12 +11,13 @@
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@@ -92,6 +93,22 @@ class Chunk(BaseModel):
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
@property
def document_id(self) -> str | None:
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
# Check metadata first (takes precedence)
doc_id = self.metadata.get("document_id")
if doc_id is not None:
if not isinstance(doc_id, str):
raise TypeError(f"metadata['document_id'] must be a string, got {type(doc_id).__name__}: {doc_id!r}")
return doc_id
# Fall back to chunk_metadata if available (Pydantic ensures type safety)
if self.chunk_metadata is not None:
return self.chunk_metadata.document_id
return None
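# Illustrative sketch (not part of this diff): the precedence rule implemented above.
# The Chunk/ChunkMetadata constructor fields used here are assumptions.
chunk = Chunk(
    content="some text",
    metadata={"document_id": "doc-from-metadata"},
    chunk_metadata=ChunkMetadata(document_id="doc-from-chunk-metadata"),
)
assert chunk.document_id == "doc-from-metadata"  # metadata takes precedence

fallback_chunk = Chunk(
    content="some text",
    metadata={},
    chunk_metadata=ChunkMetadata(document_id="doc-123"),
)
assert fallback_chunk.document_id == "doc-123"  # falls back to chunk_metadata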
@json_schema_type
class QueryChunksResponse(BaseModel):
@@ -123,6 +140,7 @@ class VectorStoreFileCounts(BaseModel):
total: int
# TODO: rename this as OpenAIVectorStore
@json_schema_type
class VectorStoreObject(BaseModel):
"""OpenAI Vector Store object.
@@ -466,17 +484,52 @@ class VectorStoreFilesListInBatchResponse(BaseModel):
has_more: bool = False
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store with extra_body support.
:param name: (Optional) A name for the vector store
:param file_ids: List of file IDs to include in the vector store
:param expires_after: (Optional) Expiration policy for the vector store
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
:param metadata: Set of key-value pairs that can be attached to the vector store
"""
name: str | None = None
file_ids: list[str] | None = None
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store file batch with extra_body support.
:param file_ids: A list of File IDs that the vector store should use
:param attributes: (Optional) Key-value attributes to store with the files
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto
"""
file_ids: list[str]
attributes: dict[str, Any] | None = None
chunking_strategy: VectorStoreChunkingStrategy | None = None
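# Illustrative sketch (not part of this diff): because these request models are declared
# with extra="allow", callers can tuck provider-specific settings into the request body
# and implementations can read them back via .model_extra. The embedding_model /
# embedding_dimension keys and the `vector_io` handle are assumptions; the call runs in
# an async context.
params = OpenAICreateVectorStoreRequestWithExtraBody(
    name="my-docs",
    metadata={"project": "demo"},
    embedding_model="all-MiniLM-L6-v2",  # unknown keys are preserved by extra="allow"
    embedding_dimension=384,
)
extra = params.model_extra or {}
print(extra.get("embedding_model"), extra.get("embedding_dimension"))
store = await vector_io.openai_create_vector_store(params=params)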
class VectorStoreTable(Protocol):
def get_vector_store(self, vector_store_id: str) -> VectorStore | None: ...
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
vector_db_store: VectorDBStore | None = None
vector_store_table: VectorStoreTable | None = None
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert_chunks(
self,
@@ -495,6 +548,7 @@ class VectorIO(Protocol):
"""
...
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
async def query_chunks(
self,
@@ -516,25 +570,11 @@
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
"""Creates a vector store.
:param name: A name for the vector store.
:param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
:param expires_after: The expiration policy for a vector store.
:param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
Generate an OpenAI-compatible vector store with the given parameters.
:returns: A VectorStoreObject representing the created vector store.
"""
...
@@ -827,16 +867,12 @@ class VectorIO(Protocol):
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.
Generate an OpenAI-compatible vector store file batch for the given vector store.
:param vector_store_id: The ID of the vector store to create the file batch for.
:param file_ids: A list of File IDs that the vector store should use.
:param attributes: (Optional) Key-value attributes to store with the files.
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
:returns: A VectorStoreFileBatchObject representing the created file batch.
"""
...

View file

@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_dbs import *
from .vector_stores import *

View file

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
# Internal resource type for storing the vector store routing and other information
class VectorStore(Resource):
"""Vector database resource for storing and querying vector embeddings.
:param type: Type of resource, always 'vector_store' for vector stores
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
"""
type: Literal[ResourceType.vector_store] = ResourceType.vector_store
embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
@property
def vector_store_id(self) -> str:
return self.identifier
@property
def provider_vector_store_id(self) -> str | None:
return self.provider_resource_id
class VectorStoreInput(BaseModel):
"""Input parameters for creating or configuring a vector database.
:param vector_store_id: Unique identifier for the vector store
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
"""
vector_store_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_store_id: str | None = None

View file

@@ -1,495 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
import shutil
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import partial
from pathlib import Path
import httpx
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama download",
description="Download a model from llama.meta.com or Hugging Face Hub",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_download_parser(self.parser)
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--source",
choices=["meta", "huggingface"],
default="meta",
)
parser.add_argument(
"--model-id",
required=False,
help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B",
)
parser.add_argument(
"--hf-token",
type=str,
required=False,
default=None,
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
)
parser.add_argument(
"--meta-url",
type=str,
required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
)
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument(
"--ignore-patterns",
type=str,
required=False,
default="*.safetensors",
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
)
parser.add_argument(
"--manifest-file",
type=str,
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
required=False,
)
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
@dataclass
class DownloadTask:
url: str
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: int | None = None
retries: int = 0
max_retries: int = 3
class DownloadError(Exception):
pass
class CustomTransferSpeedColumn(TransferSpeedColumn):
def render(self, task):
if task.finished:
return "-"
return super().render(task)
class ParallelDownloader:
def __init__(
self,
max_concurrent_downloads: int = 3,
buffer_size: int = 1024 * 1024,
timeout: int = 30,
):
self.max_concurrent_downloads = max_concurrent_downloads
self.buffer_size = buffer_size
self.timeout = timeout
self.console = Console()
self.progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=40),
"[progress.percentage]{task.percentage:>3.1f}%",
DownloadColumn(),
CustomTransferSpeedColumn(),
TimeRemainingColumn(),
console=self.console,
expand=True,
)
self.client_options = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
}
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
last_exception = None
for attempt in range(task.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < task.max_retries - 1:
wait_time = min(30, 2**attempt) # Cap at 30 seconds
self.console.print(
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
)
await asyncio.sleep(wait_time)
continue
raise last_exception
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
if task.total_size > 0:
self.progress.update(task.task_id, total=task.total_size)
return
async def _get_info():
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response.raise_for_status()
return response
try:
response = await self.retry_with_exponential_backoff(task, _get_info)
task.url = str(response.url)
task.total_size = int(response.headers.get("Content-Length", 0))
if task.total_size == 0:
raise DownloadError(
f"Unable to determine file size for {task.output_file}. "
"The server might not support range requests."
)
# Update the progress bar's total size once we know it
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
except httpx.HTTPError as e:
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
raise
def verify_file_integrity(self, task: DownloadTask) -> bool:
if not os.path.exists(task.output_file):
return False
return os.path.getsize(task.output_file) == task.total_size
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
file.seek(start)
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
try:
await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e:
raise DownloadError(
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
) from e
async def prepare_download(self, task: DownloadTask) -> None:
output_dir = os.path.dirname(task.output_file)
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(task.output_file):
task.downloaded_size = os.path.getsize(task.output_file)
async def download_file(self, task: DownloadTask) -> None:
try:
async with httpx.AsyncClient(**self.client_options) as client:
await self.get_file_info(client, task)
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
try:
# Split the remaining download into chunks
chunk_size = 27_000_000_000 # Cloudfront max chunk size
chunks = []
current_pos = task.downloaded_size
while current_pos < task.total_size:
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1
# Download chunks in sequence
for chunk_start, chunk_end in chunks:
await self.download_chunk(client, task, chunk_start, chunk_end)
except Exception as e:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
try:
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free
# Add 10% buffer for safety
required_space = int(total_remaining_size * 1.1)
if free_space < required_space:
self.console.print(
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
f"Available: {free_space // (1024 * 1024)} MB[/red]"
)
return False
return True
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: list[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
async def download_with_semaphore(task: DownloadTask):
async with semaphore:
try:
await self.download_file(task)
except Exception as e:
failed_tasks.append((task, str(e)))
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
raise DownloadError(f"{len(failed_tasks)} downloads failed")
def _hf_download(
model: "Model",
hf_token: str,
ignore_patterns: str,
parser: argparse.ArgumentParser,
):
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_stack.core.utils.model_utils import model_local_dir
repo_id = model.huggingface_repo
if repo_id is None:
raise ValueError(f"No repo id found for model {model.descriptor()}")
output_dir = model_local_dir(model.descriptor())
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
repo_id,
local_dir=output_dir,
ignore_patterns=ignore_patterns,
token=hf_token,
library_name="llama-stack",
)
except GatedRepoError:
parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
except Exception as e:
parser.error(e)
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
from llama_stack.core.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
# Create download tasks for each file
tasks = []
for f in info.files:
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
file=sys.stderr,
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
color="yellow",
file=sys.stderr,
)
class ModelEntry(BaseModel):
model_id: str
files: dict[str, str]
model_config = ConfigDict(protected_namespaces=())
class Manifest(BaseModel):
models: list[ModelEntry]
expires_on: datetime
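# Illustrative sketch (not part of this diff): a manifest consumed by --manifest-file
# only needs to satisfy the two models above. The model id, file names and URLs below
# are placeholders.
from datetime import UTC, datetime, timedelta

example_manifest = Manifest(
    models=[
        ModelEntry(
            model_id="Llama3.2-3B",
            files={
                "consolidated.00.pth": "https://example.com/signed/consolidated.00.pth",
                "params.json": "https://example.com/signed/params.json",
                "tokenizer.model": "https://example.com/signed/tokenizer.model",
            },
        )
    ],
    expires_on=datetime.now(UTC) + timedelta(days=1),
)
with open("example_manifest.json", "w") as f:
    f.write(example_manifest.model_dump_json(indent=2))
# then: llama download --manifest-file example_manifest.json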
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.core.utils.model_utils import model_local_dir
with open(manifest_file) as f:
d = json.load(f)
manifest = Manifest(**d)
if datetime.now(UTC) > manifest.expires_on.astimezone(UTC):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
for entry in manifest.models:
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
while True:
resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
break
elif resp.lower() in ["continue", "c"]:
console.print("[blue]Continuing download...[/blue]")
break
else:
console.print("[red]Invalid response. Please try again.[/red]")
# Create download tasks for all files in the manifest
tasks = [
DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
for fname, url in entry.files.items()
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Main download command handler"""
try:
if args.manifest_file:
_download_from_manifest(args.manifest_file, args.max_parallel)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
# Handle comma-separated model IDs
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import (
prompt_guard_download_info_map,
prompt_guard_model_sku_map,
)
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
prompt_guard_download_info_map = prompt_guard_download_info_map()
for model_id in model_ids:
if model_id in prompt_guard_model_sku_map.keys():
model = prompt_guard_model_sku_map[model_id]
info = prompt_guard_download_info_map[model_id]
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email "
f"after visiting https://www.llama.com/llama-downloads/ "
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")

View file

@@ -6,11 +6,10 @@
import argparse
from .download import Download
from .model import ModelParser
from llama_stack.log import setup_logging
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload
class LlamaCLIParser:
@@ -30,10 +29,7 @@ class LlamaCLIParser:
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
ModelParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
print_subcommand_description(self.parser, subparsers)
@@ -48,6 +44,9 @@ class LlamaCLIParser:
def main():
# Initialize logging from environment variables before any other operations
setup_logging()
parser = LlamaCLIParser()
args = parser.parse_args()
parser.run(args)

View file

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .model import ModelParser # noqa

View file

@@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
class ModelDescribe(Subcommand):
"""Show details about a model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"describe",
prog="llama model describe",
description="Show details about a llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_describe_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-id",
type=str,
required=True,
help="See `llama model list` or `llama model list --show-all` for the list of available models",
)
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku_map
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model_id in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model_id]
else:
model = resolve_model(args.model_id)
if model is None:
self.parser.error(
f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
)
return
headers = [
"Model",
model.descriptor(),
]
rows = [
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
("Description", model.description),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
("Weights format", model.quantization_format.value),
("Model params.json", json.dumps(model.arch_args, indent=4)),
]
print_table(
rows,
headers,
separate_rows=True,
)

View file

@@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama model download",
description="Download a model from llama.meta.com or Hugging Face Hub",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.download import setup_download_parser
setup_download_parser(self.parser)

View file

@@ -1,119 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import time
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import all_registered_models
def _get_model_size(model_dir):
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
def _convert_to_model_descriptor(model):
for m in all_registered_models():
if model == m.descriptor().replace(":", "-"):
return str(m.descriptor())
return str(model)
def _run_model_list_downloaded_cmd() -> None:
headers = ["Model", "Size", "Modified Time"]
rows = []
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
space_usage = _get_model_size(abs_path)
model_size = f"{space_usage / (1024**3):.2f} GB"
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
rows.append(
[
_convert_to_model_descriptor(model),
model_size,
modified_time,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
class ModelList(Subcommand):
"""List available llama models"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama model list",
description="Show available llama models",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_list_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--show-all",
action="store_true",
help="Show all models (not just defaults)",
)
self.parser.add_argument(
"--downloaded",
action="store_true",
help="List the downloaded models",
)
self.parser.add_argument(
"-s",
"--search",
type=str,
required=False,
help="Search for the input string as a substring in the model descriptor(ID)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_skus
if args.downloaded:
return _run_model_list_downloaded_cmd()
headers = [
"Model Descriptor(ID)",
"Hugging Face Repo",
"Context Length",
]
rows = []
for model in all_registered_models() + prompt_guard_model_skus():
if not args.show_all and not model.is_featured:
continue
descriptor = model.descriptor()
if not args.search or args.search.lower() in descriptor.lower():
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
)
if len(rows) == 0:
print(f"Did not find any model matching `{args.search}`.")
else:
print_table(
rows,
headers,
separate_rows=True,
)

View file

@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
class ModelParser(Subcommand):
"""Llama cli for model interface apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"model",
prog="llama model",
description="Work with llama models",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commands
ModelDownload.create(subparsers)
ModelList.create(subparsers)
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
ModelRemove.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@@ -1,133 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from io import StringIO
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent
class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"prompt-format",
prog="llama model prompt-format",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model prompt-format <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-name",
type=str,
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
"(Run `llama model list` to see a list of valid model names)",
)
self.parser.add_argument(
"-l",
"--list",
action="store_true",
help="List all available models",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_list = [m.value for m in supported_model_ids]
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
if model_id not in supported_model_ids:
self.parser.error(
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()
elif model_family(model_id) == ModelFamily.llama3_2:
if is_multimodal(model_id):
with importlib.resources.as_file(llama_3_2_vision_file) as f:
content = f.open("r").read()
else:
with importlib.resources.as_file(llama_3_2_text_file) as f:
content = f.open("r").read()
render_markdown_to_pager(content)
def render_markdown_to_pager(markdown_content: str):
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.text import Text
class LeftAlignedHeaderMarkdown(Markdown):
def parse_header(self, token):
level = token.type.count("h")
content = Text(token.content)
header_style = Style(color="bright_blue", bold=True)
header = Text(f"{'#' * level} ", style=header_style) + content
self.add_text(header)
# Render the Markdown
md = LeftAlignedHeaderMarkdown(markdown_content)
# Capture the rendered output
output = StringIO()
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
print(rendered_content)

View file

@@ -1,68 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import shutil
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import resolve_model
class ModelRemove(Subcommand):
"""Remove the downloaded llama model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"remove",
prog="llama model remove",
description="Remove the downloaded llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_remove_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model",
required=True,
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
)
self.parser.add_argument(
"-f",
"--force",
action="store_true",
help="Used to forcefully remove the llama model from the storage without further confirmation",
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku_map
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model]
else:
model = resolve_model(args.model)
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
if model is None or not os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return
if args.force:
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
print("Removal aborted.")

View file

@@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
model_id: str
huggingface_repo: str
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
is_featured: bool = False
max_seq_length: int = 512
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: dict[str, Any] = Field(default_factory=dict)
def descriptor(self) -> str:
return self.model_id
model_config = ConfigDict(protected_namespaces=())
def prompt_guard_model_skus():
return [
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-86M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-22M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
),
]
def prompt_guard_model_sku_map() -> dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()}
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
return {
model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
files=[
"model.safetensors",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer_config.json",
],
pth_size=1,
)
for model in prompt_guard_model_skus()
}

View file

@@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums for models downloaded from Meta",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.verify_download import setup_verify_download_parser
setup_verify_download_parser(self.parser)

View file

@@ -1,490 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import importlib.resources
import json
import os
import shutil
import sys
import textwrap
from functools import lru_cache
from importlib.abc import Traversable
from pathlib import Path
import yaml
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import colored, cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.table import print_table
from llama_stack.core.build import (
SERVER_DEPENDENCIES,
build_image,
get_provider_dependencies,
)
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import (
BuildConfig,
BuildProvider,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.stack import replace_env_vars
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.exec import formulate_run_args, run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
@lru_cache
def available_distros_specs() -> dict[str, BuildConfig]:
import yaml
distro_specs = {}
for p in DISTRIBS_PATH.rglob("*build.yaml"):
distro_name = p.parent.name
with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f))
distro_specs[distro_name] = build_config
return distro_specs
def run_stack_build_command(args: argparse.Namespace) -> None:
if args.list_distros:
return _run_distro_list_cmd()
if args.image_type == ImageType.VENV.value:
current_venv = os.environ.get("VIRTUAL_ENV")
image_name = args.image_name or current_venv
else:
image_name = args.image_name
if args.template:
cprint(
"The --template argument is deprecated. Please use --distro instead.",
color="red",
file=sys.stderr,
)
distro_name = args.template
else:
distro_name = args.distribution
if distro_name:
available_distros = available_distros_specs()
if distro_name not in available_distros:
cprint(
f"Could not find distribution {distro_name}. Please run `llama stack build --list-distros` to check out the available distributions",
color="red",
file=sys.stderr,
)
sys.exit(1)
build_config = available_distros[distro_name]
if args.image_type:
build_config.image_type = args.image_type
else:
cprint(
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {distro_name}",
color="red",
file=sys.stderr,
)
sys.exit(1)
elif args.providers:
provider_list: dict[str, list[BuildProvider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider_type in providers_for_api:
provider = BuildProvider(
provider_type=provider_type,
module=None,
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=provider_list,
description=",".join(args.providers),
)
if not args.image_type:
cprint(
f"Please specify a image-type (container | venv) for {args.template}",
color="red",
file=sys.stderr,
)
sys.exit(1)
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
elif not args.config and not distro_name:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
validator=Validator.from_callable(
lambda x: len(x) > 0,
error_message="Name cannot be empty, please enter a name",
),
)
image_type = prompt(
"> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
completer=WordCompleter([e.value for e in ImageType]),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in [e.value for e in ImageType],
error_message="Invalid image type. Use <TAB> to see options",
),
)
image_name = f"llamastack-{name}"
cprint(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green",
file=sys.stderr,
)
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers: dict[str, list[BuildProvider]] = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers:
continue
api_provider = prompt(
f"> Enter provider for API {api.value}: ",
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options",
),
)
string_providers = api_provider.split(" ")
for provider in string_providers:
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
description = prompt(
"\n > (Optional) Enter a short description for your Llama Stack: ",
default="",
)
distribution_spec = DistributionSpec(
providers=providers,
description=description,
)
build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else:
with open(args.config) as f:
try:
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
if args.image_type:
build_config.image_type = args.image_type
except Exception as e:
cprint(
f"Could not parse config file {args.config}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if args.print_deps_only:
print(f"# Dependencies for {distro_name or args.config or image_name}")
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:
print(f"uv pip install {special_dep}")
for external_dep in external_provider_dependencies:
print(f"uv pip install {external_dep}")
return
try:
run_config = _run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
config_path=args.config,
distro_name=distro_name,
)
except Exception as exc:
import traceback
cprint(
f"Error building stack: {exc}",
color="red",
file=sys.stderr,
)
cprint("Stack trace:", color="red", file=sys.stderr)
traceback.print_exc()
sys.exit(1)
if run_config is None:
cprint(
"Run config path is empty",
color="red",
file=sys.stderr,
)
sys.exit(1)
if args.run:
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, image_name or config.image_name)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
run_command(run_args)
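For illustration, a minimal sketch of driving this entry point programmatically instead of via `llama stack build`. The attribute names mirror the arguments registered by the `build` subcommand shown later in this diff; the distribution name "starter" is an assumption and may not exist in every checkout.

import argparse

from llama_stack.cli.stack._build import run_stack_build_command

# Hypothetical equivalent of: llama stack build --distro starter --image-type venv --print-deps-only
args = argparse.Namespace(
    config=None,
    template=None,            # deprecated alias for --distro
    distribution="starter",   # assumed distro name; use --list-distros for the real catalog
    list_distros=False,
    image_type="venv",
    image_name=None,
    print_deps_only=True,     # only print dependencies, do not build
    run=False,
    providers=None,
)
run_stack_build_command(args)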
def _generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> Path:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name,
apis=apis,
providers={},
external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
if not run_config.inference_store:
run_config.inference_store = SqliteSqlStoreConfig(
**SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
)
)
# build providers dict
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
providers = build_config.distribution_spec.providers[api]
for provider in providers:
pid = provider.provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
except (ModuleNotFoundError, ValueError) as exc:
# HACK ALERT:
# This code executes after the build is done; the import cannot work because the
# package is only available in the venv or container, not on the host.
# TODO: use an "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
color="yellow",
file=sys.stderr,
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
p_spec = Provider(
provider_id=pid,
provider_type=provider.provider_type,
config=config,
module=provider.module,
)
run_config.providers[api].append(p_spec)
run_config_file = build_dir / f"{image_name}-run.yaml"
with open(run_config_file, "w") as f:
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
# Only print this message for non-container builds; for container builds it would be shown
# before the container is actually built.
# For non-container builds, the run.yaml is generated at the very end of the build process, so it
# makes sense to display this message here.
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
return run_config_file
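As a rough usage sketch, the returned path can be loaded back to inspect the generated template. The distribution directory below is an assumed example location, and the commented keys only reflect the fields populated above.

from pathlib import Path

import yaml

# Assumed location: DISTRIBS_BASE_DIR / <image_name> / <image_name>-run.yaml
run_config_file = Path.home() / ".llama" / "distributions" / "my-stack" / "my-stack-run.yaml"
contents = yaml.safe_load(run_config_file.read_text())
print(sorted(contents.keys()))       # includes apis, image_name, providers, ...
print(list(contents["providers"]))   # one entry per API selected at build time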
def _run_stack_build_command_from_build_config(
build_config: BuildConfig,
image_name: str | None = None,
distro_name: str | None = None,
config_path: str | None = None,
) -> Path | Traversable:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if distro_name:
image_name = f"distribution-{distro_name}"
else:
if not image_name:
raise ValueError("Please specify an image name when building a container image without a template")
else:
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
image_name = "__system__"
if not image_name:
raise ValueError("Please specify an image name when building a venv image")
# At this point, image_name should be guaranteed to be a string
if image_name is None:
raise ValueError("image_name should not be None after validation")
if distro_name:
build_dir = DISTRIBS_BASE_DIR / distro_name
build_file_path = build_dir / f"{distro_name}-build.yaml"
else:
if image_name is None:
raise ValueError("image_name cannot be None")
build_dir = DISTRIBS_BASE_DIR / image_name
build_file_path = build_dir / f"{image_name}-build.yaml"
os.makedirs(build_dir, exist_ok=True)
run_config_file = None
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
# Only do this if we're building a container image from a config file rather than a distribution
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not distro_name and config_path:
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the
# providers dependencies
if build_config.external_apis_dir:
cprint("Installing external APIs", color="yellow", file=sys.stderr)
external_apis = load_external_apis(build_config)
if external_apis:
# install the external APIs
packages = []
for _, api_spec in external_apis.items():
if api_spec.pip_packages:
packages.extend(api_spec.pip_packages)
cprint(
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
color="yellow",
file=sys.stderr,
)
return_code = run_command(["uv", "pip", "install", *packages])
if return_code != 0:
packages_str = ", ".join(packages)
raise RuntimeError(
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
)
return_code = build_image(
build_config,
image_name,
distro_or_config=distro_name or config_path or str(build_file_path),
run_config=run_config_file.as_posix() if run_config_file else None,
)
if return_code != 0:
raise RuntimeError(f"Failed to build image {image_name}")
if distro_name:
# copy run.yaml from distribution to build_dir instead of generating it again
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
run_config_file = build_dir / f"{distro_name}-run.yaml"
with importlib.resources.as_file(distro_path) as path:
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
if build_config.image_type == LlamaStackImageType.VENV.value:
cprint(
"You can run the new Llama Stack distro (after activating "
+ colored(image_name, "cyan")
+ ") via: "
+ colored(f"llama stack run {run_config_file}", "blue"),
color="green",
file=sys.stderr,
)
elif build_config.image_type == LlamaStackImageType.CONTAINER.value:
cprint(
"You can run the container with: "
+ colored(
f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
),
color="green",
file=sys.stderr,
)
return distro_path
else:
return _generate_run_config(build_config, build_dir, image_name)
def _run_distro_list_cmd() -> None:
headers = [
"Distribution Name",
# "Providers",
"Description",
]
rows = []
for distro_name, spec in available_distros_specs().items():
rows.append(
[
distro_name,
# json.dumps(spec.distribution_spec.providers, indent=2),
spec.distribution_spec.description,
]
)
print_table(
rows,
headers,
separate_rows=True,
)

View file

@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.core.build import get_provider_dependencies
from llama_stack.core.datatypes import (
BuildConfig,
BuildProvider,
DistributionSpec,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
logger = get_logger(name=__name__, category="cli")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
"uvicorn",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
]
def format_output_deps_only(
normal_deps: list[str],
special_deps: list[str],
external_deps: list[str],
uv: bool = False,
) -> str:
"""Format dependencies as a list."""
lines = []
uv_str = ""
if uv:
uv_str = "uv pip install "
# Quote deps that contain characters needing shell escaping (commas, comparison operators)
quoted_normal_deps = [quote_if_needed(dep) for dep in normal_deps]
lines.append(f"{uv_str}{' '.join(quoted_normal_deps)}")
for special_dep in special_deps:
lines.append(f"{uv_str}{quote_special_dep(special_dep)}")
for external_dep in external_deps:
lines.append(f"{uv_str}{quote_special_dep(external_dep)}")
return "\n".join(lines)
def run_stack_list_deps_command(args: argparse.Namespace) -> None:
if args.config:
try:
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
config_file = resolve_config_or_distro(args.config, Mode.BUILD)
except ValueError as e:
cprint(
f"Could not parse config file {args.config}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if config_file:
with open(config_file) as f:
try:
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
build_config.image_type = "venv"
except Exception as e:
cprint(
f"Could not parse config file {config_file}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
elif args.providers:
provider_list: dict[str, list[BuildProvider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider_type in providers_for_api:
provider = BuildProvider(
provider_type=provider_type,
module=None,
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider_type} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=provider_list,
description=",".join(args.providers),
)
build_config = BuildConfig(image_type=ImageType.VENV.value, distribution_spec=distribution_spec)
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
# Add external API dependencies
if build_config.external_apis_dir:
from llama_stack.core.external import load_external_apis
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
# Format and output based on requested format
output = format_output_deps_only(
normal_deps=normal_deps,
special_deps=special_deps,
external_deps=external_provider_dependencies,
uv=args.format == "uv",
)
print(output)
def quote_if_needed(dep):
# Add quotes if the dependency contains special characters that need escaping in shell
# This includes: commas, comparison operators (<, >, <=, >=, ==, !=)
needs_quoting = any(char in dep for char in [",", "<", ">", "="])
return f"'{dep}'" if needs_quoting else dep
def quote_special_dep(dep_string):
"""
Quote individual packages in a special dependency string.
Special deps may contain multiple packages and flags like --extra-index-url.
We need to quote only the package specs that contain special characters.
"""
parts = dep_string.split()
quoted_parts = []
for part in parts:
# Don't quote flags (they start with -)
if part.startswith("-"):
quoted_parts.append(part)
else:
# Quote package specs that need it
quoted_parts.append(quote_if_needed(part))
return " ".join(quoted_parts)

View file

@@ -1,100 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
class StackBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)
def _add_arguments(self):
self.parser.add_argument(
"--config",
type=str,
default=None,
help="Path to a config file to use for the build. You can find example configs in llama_stack.cores/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
)
self.parser.add_argument(
"--template",
type=str,
default=None,
help="""(deprecated) Name of the example template config to use for build. You may use `llama stack build --list-distros` to check out the available distributions""",
)
self.parser.add_argument(
"--distro",
"--distribution",
dest="distribution",
type=str,
default=None,
help="""Name of the distribution to use for build. You may use `llama stack build --list-distros` to check out the available distributions""",
)
self.parser.add_argument(
"--list-distros",
"--list-distributions",
action="store_true",
dest="list_distros",
default=False,
help="Show the available distributions for building a Llama Stack distribution",
)
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
choices=[e.value for e in ImageType],
default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
)
self.parser.add_argument(
"--image-name",
type=str,
help=textwrap.dedent(
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the virtual environment to use for
the build. If not specified, the currently active environment will be used if found.
"""
),
default=None,
)
self.parser.add_argument(
"--print-deps-only",
default=False,
action="store_true",
help="Print the dependencies for the stack only, without building the stack",
)
self.parser.add_argument(
"--run",
action="store_true",
default=False,
help="Run the stack after building using the same image type, name, and other applicable arguments",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
from ._build import run_stack_build_command
return run_stack_build_command(args)

View file

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackListDeps(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list-deps",
prog="llama stack list-deps",
description="list the dependencies for a llama stack distribution",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_list_deps_command)
def _add_arguments(self):
self.parser.add_argument(
"config",
type=str,
nargs="?", # Make it optional
metavar="config | distro",
help="Path to config file to use or name of known distro (llama stack list for a list).",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="sync dependencies for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)
self.parser.add_argument(
"--format",
type=str,
choices=["uv", "deps-only"],
default="deps-only",
help="Output format: 'uv' shows shell commands, 'deps-only' shows just the list of dependencies without `uv` (default)",
)
def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
from ._list_deps import run_stack_list_deps_command
return run_stack_list_deps_command(args)

View file

@@ -15,10 +15,10 @@ import yaml
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import get_logger
from llama_stack.log import LoggingConfig, get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent

View file

@@ -11,8 +11,8 @@ from llama_stack.cli.stack.list_stacks import StackListBuilds
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .list_apis import StackListApis
from .list_deps import StackListDeps
from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun
@@ -39,7 +39,7 @@ class StackParser(Subcommand):
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackBuild.create(subparsers)
StackListDeps.create(subparsers)
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)

View file

@@ -4,7 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import sys
from enum import Enum
from functools import lru_cache
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.core.datatypes import (
BuildConfig,
Provider,
StackRunConfig,
StorageConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
)
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
class ImageType(Enum):
@@ -19,3 +49,103 @@ def print_subcommand_description(parser, subparsers):
description = subcommand.description
description_text += f" {name:<21} {description}\n"
parser.epilog = description_text
def generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> Path:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
apis = list(build_config.distribution_spec.providers.keys())
distro_dir = DISTRIBS_BASE_DIR / image_name
run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name,
apis=apis,
providers={},
storage=StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=str(distro_dir / "kvstore.db")),
"sql_default": SqliteSqlStoreConfig(db_path=str(distro_dir / "sql_store.db")),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
),
),
external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
# build providers dict
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
providers = build_config.distribution_spec.providers[api]
for provider in providers:
pid = provider.provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
except (ModuleNotFoundError, ValueError) as exc:
# HACK ALERT:
# This code executes after the build is done; the import cannot work because the
# package is only available in the venv or container, not on the host.
# TODO: use an "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
color="yellow",
file=sys.stderr,
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
p_spec = Provider(
provider_id=pid,
provider_type=provider.provider_type,
config=config,
module=provider.module,
)
run_config.providers[api].append(p_spec)
run_config_file = build_dir / f"{image_name}-run.yaml"
with open(run_config_file, "w") as f:
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
# Only print this message for non-container builds; for container builds it would be shown
# before the container is actually built.
# For non-container builds, the run.yaml is generated at the very end of the build process, so it
# makes sense to display this message here.
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
return run_config_file
@lru_cache
def available_templates_specs() -> dict[str, BuildConfig]:
import yaml
template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name
with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config
return template_specs

View file

@@ -1,141 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.cli.subcommand import Subcommand
@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: str | None
exists: bool
matches: bool
class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify (only for models downloaded from Meta)",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
sha256_hash = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
sha256_hash.update(chunk)
return sha256_hash.hexdigest()
def load_checksums(checklist_path: Path) -> dict[str, str]:
checksums = {}
with open(checklist_path) as f:
for line in f:
if line.strip():
sha256sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = sha256sum
return checksums
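For reference, a small self-contained sketch of the checklist.chk layout this parser expects: one hash and one path per line, separated by a single space (real checklist files may differ). The hashes and filenames below are placeholders, and load_checksums from above is assumed to be in scope.

import tempfile
from pathlib import Path

# Placeholder hashes and filenames, written in the "<sha256> <path>" layout parsed above.
sample = f"{'a' * 64} ./params.json\n{'b' * 64} tokenizer.model\n"
with tempfile.TemporaryDirectory() as d:
    checklist = Path(d) / "checklist.chk"
    checklist.write_text(sample)
    print(load_checksums(checklist))
    # {'params.json': 'aaa...', 'tokenizer.model': 'bbb...'}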
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
results = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
matches = False
if exists:
actual_hash = calculate_sha256(full_path)
matches = actual_hash == expected_hash
results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)
progress.remove_task(task_id)
return results
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.core.utils.model_utils import model_local_dir
console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"
if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")
if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")
checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)
# Print results
console.print("\nVerification Results:")
all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")
if all_good:
console.print("\n[green]All files verified successfully![/green]")

View file

@@ -41,7 +41,7 @@ class AccessRule(BaseModel):
A rule defines a list of actions either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
@@ -79,9 +79,9 @@ class AccessRule(BaseModel):
description: any user has read access to any resource created by a member of their team
- forbid:
actions: [create, read, delete]
resource: vector_db::*
resource: vector_store::*
unless: user with admin in roles
description: only user with admin role can use vector_db resources
description: only user with admin role can use vector_store resources
"""

View file

@@ -1,410 +0,0 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
BUILD_PLATFORM=${BUILD_PLATFORM:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml
BUILD_CONTEXT_DIR=$(pwd)
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
# Usage function
usage() {
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
image_name=""
container_base=""
normal_deps=""
external_provider_deps=""
optional_deps=""
run_config=""
distro_or_config=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--image-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --image-name requires a string value" >&2
usage
fi
image_name="$2"
shift 2
;;
--container-base)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --container-base requires a string value" >&2
usage
fi
container_base="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
--run-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --run-config requires a string value" >&2
usage
fi
run_config="$2"
shift 2
;;
--distro-or-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --distro-or-config requires a string value" >&2
usage
fi
distro_or_config="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
usage
fi
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() {
output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
cat >>"$output_file"
fi
}
if ! is_command_available "$CONTAINER_BINARY"; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
WORKDIR /app
# We install the Python 3.12 dev headers and build tools so that any
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
EOF
else
add_to_container << EOF
FROM $container_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget git \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc g++ \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
EOF
fi
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$normal_deps" ]; then
read -ra pip_args <<< "$normal_deps"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container << EOF
RUN uv pip install --no-cache $quoted_deps
EOF
fi
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN uv pip install --no-cache $quoted_deps
EOF
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN uv pip install --no-cache $quoted_deps
EOF
add_to_container <<EOF
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
import importlib
import sys
try:
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
module = importlib.import_module(f'{package_name}.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
if isinstance(spec.pip_packages, (list, tuple)):
print('\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
PYTHON
EOF
done
fi
get_python_cmd() {
if is_command_available python; then
echo "python"
elif is_command_available python3; then
echo "python3"
else
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
exit 1
fi
}
if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
# Parse the run.yaml configuration to identify external provider directories
# If external providers are specified, copy their directory to the container
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ]; then
if [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY providers.d /.llama/providers.d
EOF
fi
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
# Copy run config into docker image
add_to_container << EOF
COPY run.yaml $RUN_CONFIG_PATH
EOF
fi
stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"
install_local_package() {
local dir="$1"
local mount_point="$2"
local name="$3"
if [ ! -d "$dir" ]; then
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
exit 1
fi
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $mount_point
EOF
}
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
add_to_container << EOF
RUN uv pip install --no-cache fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack==$TEST_PYPI_VERSION
EOF
else
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
add_to_container << EOF
RUN uv pip install --no-cache $SPEC_VERSION
EOF
fi
fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# If a run config is provided, we use the llama stack CLI
if [[ -n "$run_config" ]]; then
add_to_container << EOF
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
EOF
elif [[ "$distro_or_config" != *.yaml ]]; then
add_to_container << EOF
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
EOF
fi
# Add other required commands generic to all containers
add_to_container << EOF
RUN mkdir -p /.llama /.cache && chmod -R g+rw /app /.llama /.cache
EOF
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"
# Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi
fi
if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CLI_ARGS+=("--security-opt" "label=disable")
fi
# Set version tag based on PyPI version
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
version_tag="dev"
else
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
# Add version tag to image name
image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
CLI_ARGS+=("--platform" "linux/amd64")
else
echo "Unsupported architecture: $ARCH"
exit 1
fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
$CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"$BUILD_CONTEXT_DIR"
# clean up tmp/configs
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
set +x
echo "Success!"

View file

@@ -1,220 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
echo "Error: --env-name and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
ENVNAME=""
# pre-run checks to make sure we can proceed with the installation
pre_run_checks() {
local env_name="$1"
if ! is_command_available uv; then
echo "uv is not installed, trying to install it."
if ! is_command_available pip; then
echo "pip is not installed, cannot automatically install 'uv'."
echo "Follow this link to install it:"
echo "https://docs.astral.sh/uv/getting-started/installation/"
exit 1
else
pip install uv
fi
fi
# checking if an environment with the same name already exists
if [ -d "$env_name" ]; then
echo "Environment '$env_name' already exists, re-using it."
fi
}
run() {
# Use only global variables set by flag parser
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment"
export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
source "$env_name/bin/activate"
fi
if [ -n "$TEST_PYPI_VERSION" ]; then
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack=="$TEST_PYPI_VERSION" \
$normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
if [ -n "$LLAMA_STACK_DIR" ]; then
# only warn if DIR does not start with "git+"
if [ ! -d "$LLAMA_STACK_DIR" ] && [[ "$LLAMA_STACK_DIR" != git+* ]]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
# editable only if LLAMA_STACK_DIR does not start with "git+"
if [[ "$LLAMA_STACK_DIR" != git+* ]]; then
EDITABLE="-e"
else
EDITABLE=""
fi
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_DIR"
else
uv pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
# only warn if DIR does not start with "git+"
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ] && [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
# editable only if LLAMA_STACK_CLIENT_DIR does not start with "git+"
if [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
EDITABLE="-e"
else
EDITABLE=""
fi
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_CLIENT_DIR"
fi
printf "Installing pip dependencies\n"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "Installing special provider module: $part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Installing external provider module: $part"
uv pip install "$part"
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
}
pre_run_checks "$env_name"
run

View file

@@ -64,7 +64,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
apis_to_serve = [a.value for a in Api if a not in (Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)
@@ -159,6 +159,37 @@ def upgrade_from_routing_table(
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
# Add default storage config if not present
if "storage" not in config_dict:
config_dict["storage"] = {
"backends": {
"kv_default": {
"type": "kv_sqlite",
"db_path": "~/.llama/kvstore.db",
},
"sql_default": {
"type": "sql_sqlite",
"db_path": "~/.llama/sql_store.db",
},
},
"stores": {
"metadata": {
"namespace": "registry",
"backend": "kv_default",
},
"inference": {
"table_name": "inference_store",
"backend": "sql_default",
"max_write_queue_size": 10000,
"num_writers": 4,
},
"conversations": {
"table_name": "openai_conversations",
"backend": "sql_default",
},
},
}
return config_dict

View file

@@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from typing import Any, Literal
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
@@ -17,20 +15,16 @@ from llama_stack.apis.conversations.conversations import (
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemInclude,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
logger = get_logger(name=__name__, category="openai_conversations")
@@ -38,13 +32,11 @@ logger = get_logger(name=__name__, category="openai_conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
:param run_config: Stack run configuration for resolving persistence
:param policy: Access control rules
"""
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
)
run_config: StackRunConfig
policy: list[AccessRule] = []
@@ -63,14 +55,16 @@ class ConversationServiceImpl(Conversations):
self.deps = deps
self.policy = config.policy
base_sql_store = sqlstore_impl(config.conversations_store)
# Use conversations store reference from run config
conversations_ref = config.run_config.storage.stores.conversations
if not conversations_ref:
raise ValueError("storage.stores.conversations must be configured in run config")
base_sql_store = sqlstore_impl(conversations_ref)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
{
@@ -135,7 +129,7 @@ class ConversationServiceImpl(Conversations):
object="conversation",
)
logger.info(f"Created conversation {conversation_id}")
logger.debug(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
@@ -161,7 +155,7 @@ class ConversationServiceImpl(Conversations):
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.info(f"Deleted conversation {conversation_id}")
logger.debug(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
@@ -193,12 +187,15 @@ class ConversationServiceImpl(Conversations):
await self._get_validated_conversation(conversation_id)
created_items = []
created_at = int(time.time())
base_time = int(time.time())
for item in items:
for i, item in enumerate(items):
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
# make each timestamp unique to maintain order
created_at = base_time + i
item_record = {
"id": item_id,
"conversation_id": conversation_id,
@@ -219,7 +216,7 @@ class ConversationServiceImpl(Conversations):
created_items.append(item_dict)
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
logger.debug(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
@@ -250,19 +247,30 @@ class ConversationServiceImpl(Conversations):
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(record["item_data"])
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items in the conversation."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
# check if conversation exists
await self.get_conversation(conversation_id)
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
records = result.data
if order != NOT_GIVEN and order == "asc":
if order is not None and order == "asc":
records.sort(key=lambda x: x["created_at"])
else:
records.sort(key=lambda x: x["created_at"], reverse=True)
actual_limit = 20
if limit != NOT_GIVEN and isinstance(limit, int):
actual_limit = limit
actual_limit = limit or 20
records = records[:actual_limit]
items = [record["item_data"] for record in records]
@@ -302,5 +310,5 @@ class ConversationServiceImpl(Conversations):
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
logger.debug(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@@ -23,12 +23,16 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
KVStoreReference,
StorageBackendType,
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
@@ -68,7 +72,7 @@ class ShieldWithOwner(Shield, ResourceWithOwner):
pass
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
pass
@@ -88,12 +92,12 @@ class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithOwner
| ShieldWithOwner
| VectorDBWithOwner
| VectorStoreWithOwner
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
@@ -176,12 +180,20 @@ class DistributionSpec(BaseModel):
)
class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
class TelemetryConfig(BaseModel):
"""
Configuration for telemetry.
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
for env variables to configure the OpenTelemetry SDK.
Example:
```bash
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
```
"""
enabled: bool = Field(default=False, description="enable or disable telemetry")
class OAuth2JWKSConfig(BaseModel):
@ -335,12 +347,41 @@ class AuthenticationRequiredError(Exception):
pass
class QualifiedModel(BaseModel):
"""A qualified model identifier, consisting of a provider ID and a model ID."""
provider_id: str
model_id: str
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
default_provider_id: str | None = Field(
default=None,
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
)
default_embedding_model: QualifiedModel | None = Field(
default=None,
description="Default embedding model configuration for vector stores.",
)
class SafetyConfig(BaseModel):
"""Configuration for default moderations model."""
default_shield_id: str | None = Field(
default=None,
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
)
class QuotaPeriod(StrEnum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
@ -383,6 +424,18 @@ def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | N
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class RegisteredResources(BaseModel):
"""Registry of resources available in the distribution."""
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -422,18 +475,6 @@ class ServerConfig(BaseModel):
)
class InferenceStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class ResponsesStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
@ -460,39 +501,19 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re
can be instantiated multiple times (with different configs) if necessary.
""",
)
metadata_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
storage: StorageConfig = Field(
description="Catalog of named storage backends and references available to the stack",
)
inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. Can be either a
InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
If not specified, a default SQLite store will be used.""",
registered_resources: RegisteredResources = Field(
default_factory=RegisteredResources,
description="Registry of resources available in the distribution",
)
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
@ -508,6 +529,16 @@ If not specified, a default SQLite store will be used.""",
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
vector_stores: VectorStoresConfig | None = Field(
default=None,
description="Configuration for vector stores, including default embedding model",
)
safety: SafetyConfig | None = Field(
default=None,
description="Configuration for default moderations model",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
@ -517,6 +548,50 @@ If not specified, a default SQLite store will be used.""",
return Path(v)
return v
@model_validator(mode="after")
def validate_server_stores(self) -> "StackRunConfig":
backend_map = self.storage.backends
stores = self.storage.stores
kv_backends = {
name
for name, cfg in backend_map.items()
if cfg.type
in {
StorageBackendType.KV_REDIS,
StorageBackendType.KV_SQLITE,
StorageBackendType.KV_POSTGRES,
StorageBackendType.KV_MONGODB,
}
}
sql_backends = {
name
for name, cfg in backend_map.items()
if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
}
def _ensure_backend(reference, expected_set, store_name: str) -> None:
if reference is None:
return
backend_name = reference.backend
if backend_name not in backend_map:
raise ValueError(
f"{store_name} references unknown backend '{backend_name}'. "
f"Available backends: {sorted(backend_map)}"
)
if backend_name not in expected_set:
raise ValueError(
f"{store_name} references backend '{backend_name}' of type "
f"'{backend_map[backend_name].type.value}', but a backend of type "
f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
)
_ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
_ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
_ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
_ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
_ensure_backend(stores.prompts, kv_backends, "storage.stores.prompts")
return self
class BuildConfig(BaseModel):
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
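
The `validate_server_stores` validator above partitions the named backends into KV and SQL kinds and then checks that every store reference points at a backend of the matching kind. A standalone sketch of that pattern, using plain strings in place of the real `StorageConfig` types (all names below are illustrative):

```python
# Illustrative reference-check: stores must point at an existing backend of the right kind.
KV_TYPES = {"kv_redis", "kv_sqlite", "kv_postgres", "kv_mongodb"}
SQL_TYPES = {"sql_sqlite", "sql_postgres"}

def check_reference(backends: dict[str, str], backend_name: str, expect_kv: bool, store_name: str) -> None:
    if backend_name not in backends:
        raise ValueError(f"{store_name} references unknown backend '{backend_name}'")
    expected = KV_TYPES if expect_kv else SQL_TYPES
    if backends[backend_name] not in expected:
        raise ValueError(f"{store_name} requires a {'kv_*' if expect_kv else 'sql_*'} backend")

backends = {"default_kv": "kv_sqlite", "default_sql": "sql_postgres"}
check_reference(backends, "default_kv", expect_kv=True, store_name="storage.stores.metadata")
check_reference(backends, "default_sql", expect_kv=False, store_name="storage.stores.inference")
```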


@ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
@ -67,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_stores,
router_api=Api.vector_io,
),
]
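
As a quick sanity check of the change above, the vector-store routing table is now auto-routed to the vector_io router. A hedged sketch, assuming `Api` and `builtin_automatically_routed_apis` are importable from their usual modules:

```python
# Hedged check: the vector_stores routing table now routes to the vector_io API.
mapping = {info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()}
assert mapping[Api.vector_stores] == Api.vector_io
```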


@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous
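
A sketch of how a test might use the helpers above to make generated identifiers deterministic; the `"vector_store"` kind and the uuid-based factory are illustrative.

```python
import uuid

def deterministic(kind: str, factory: IdFactory) -> str:
    # ignore the real factory and return a stable id for the given kind
    return f"{kind}_0001"

previous = set_id_override(deterministic)
try:
    object_id = generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
    assert object_id == "vector_store_0001"
finally:
    reset_id_override(previous)
```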


@ -32,7 +32,7 @@ from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -44,16 +44,13 @@ from llama_stack.core.stack import (
get_stack_run_config_from_distro,
replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")
@ -204,10 +201,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
skip_logger_removal: bool = False,
):
super().__init__()
# Initialize logging from environment variables first
setup_logging()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
current_sinks = sinks_from_env.strip().lower().split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
@ -281,7 +282,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\n{prefix}llama stack build --distro {self.config_path_or_distro_name} --image-type venv\n\n",
f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
"yellow",
file=sys.stderr,
)
@ -293,8 +294,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise _e
assert self.impls is not None
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
if self.config.telemetry.enabled:
setup_logger(Telemetry())
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._convert_body(func, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@ -493,21 +495,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
if not body:
return {}
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
body = body or {}
exclude_params = exclude_params or set()
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if is_unwrapped_body_param(param_type):
base_type = get_args(param_type)[0]
return {param.name: base_type(**body)}
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Check if there's an unwrapped body parameter among multiple parameters
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
unwrapped_body_param = None
for param in params_list:
if is_unwrapped_body_param(param.annotation):
unwrapped_body_param = param
break
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
@ -518,4 +530,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
# handle unwrapped body parameter after processing all named parameters
if unwrapped_body_param:
base_type = get_args(unwrapped_body_param.annotation)[0]
# extract only keys not already used by other params
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)
return converted_body
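
For illustration, a standalone sketch of the unwrapped-body case handled above: a path parameter plus a single `Annotated[Model, Body(...)]` parameter means the named keys are consumed first and the remaining body keys build the Pydantic model. The endpoint and model below are hypothetical, not part of the library client.

```python
from typing import Annotated

from fastapi import Body
from pydantic import BaseModel

class SearchParams(BaseModel):
    query: str
    limit: int = 10

async def search(vector_store_id: str, params: Annotated[SearchParams, Body(embed=False)]):
    ...

body = {"vector_store_id": "vs_123", "query": "hello", "limit": 5}
converted = {
    "vector_store_id": body["vector_store_id"],
    # remaining keys (those not consumed by named parameters) build the body model
    "params": SearchParams(**{k: v for k, v in body.items() if k != "vector_store_id"}),
}
assert isinstance(converted["params"], SearchParams) and converted["params"].limit == 5
```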


@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class PromptServiceConfig(BaseModel):
@ -41,10 +39,11 @@ class PromptServiceImpl(Prompts):
self.kvstore: KVStore
async def initialize(self) -> None:
kvstore_config = SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix()
)
self.kvstore = await kvstore_impl(kvstore_config)
# Use prompts store reference from run config
prompts_ref = self.config.run_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import importlib.metadata
import inspect
from typing import Any
@ -26,10 +27,9 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
@ -55,7 +55,6 @@ from llama_stack.providers.datatypes import (
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
@ -81,11 +80,10 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.vector_stores: VectorStore,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasetio: DatasetIO,
Api.datasets: Datasets,
Api.scoring: Scoring,
@ -125,7 +123,6 @@ def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
@ -150,6 +147,7 @@ async def resolve_impls(
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@ -172,7 +170,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -207,9 +205,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
module="llama_stack.core.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
deps__=([info.routing_table_api.value]),
),
)
}
@ -280,9 +276,10 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
@ -391,6 +388,8 @@ async def instantiate_provider(
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
args.append(run_config.telemetry.enabled)
fn = getattr(module, method)
impl = await fn(*args)


@ -6,7 +6,10 @@
from typing import Any
from llama_stack.core.datatypes import AccessRule, RoutedProtocol
from llama_stack.core.datatypes import (
AccessRule,
RoutedProtocol,
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -26,16 +29,16 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
from ..routing_tables.vector_stores import VectorStoresRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
"vector_stores": VectorStoresRoutingTable,
}
if api.value not in api_to_tables:
@ -65,25 +68,28 @@ async def get_auto_router_impl(
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
if api == Api.inference:
inference_ref = run_config.storage.stores.inference
if not inference_ref:
raise ValueError("storage.stores.inference must be configured in run config")
inference_store = InferenceStore(
config=run_config.inference_store,
reference=inference_ref,
policy=policy,
)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
elif api == Api.safety:
api_to_dep_impl["safety_config"] = run_config.safety
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()


@ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from fastapi import Body
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -31,27 +32,34 @@ from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
RerankResponse,
StopReason,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
logger = get_logger(name=__name__, category="core::routers")
@ -62,14 +70,14 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
telemetry_enabled: bool = False,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.telemetry_enabled = telemetry_enabled
self.store = store
if self.telemetry:
if self.telemetry_enabled:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
@ -151,7 +159,7 @@ class InferenceRouter(Inference):
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
if self.telemetry_enabled:
for metric in metrics:
enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
@ -179,64 +187,43 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def openai_completion(
async def rerank(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
logger.debug(f"InferenceRouter.rerank: {model}")
model_obj = await self._get_model(model, ModelType.rerank)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.rerank(
model=model_obj.identifier,
query=query,
items=items,
max_num_results=max_num_results,
)
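
A hedged usage sketch of the new rerank route above; the router variable, model id, and documents are illustrative and assume a rerank-type model is registered.

```python
async def rerank_example(router) -> None:
    response = await router.rerank(
        model="example-provider/example-reranker",  # illustrative model id
        query="What is the capital of France?",
        items=[
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
        max_num_results=1,
    )
    print(response)
```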
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_completion(**params)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
# response_stream = await provider.openai_completion(**params)
response = await provider.openai_completion(**params)
if self.telemetry:
response = await provider.openai_completion(params)
if self.telemetry_enabled:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
@ -254,95 +241,51 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(model, ModelType.llm)
model_obj = await self._get_model(params.model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
if params.tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
if params.tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
if params.tools:
for tool in params.tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
if params.tool_choice == "none" and params.tools is not None:
params.tool_choice = None
params.tools = None
# Update params with the resolved model identifier
params.model = model_obj.identifier
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if params.stream:
response_stream = await provider.openai_chat_completion(params)
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=messages,
messages=params.messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, messages))
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry:
if self.telemetry_enabled:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
@ -359,26 +302,18 @@ class InferenceRouter(Inference):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self._get_model(model, ModelType.embedding)
params = dict(
model=model_obj.identifier,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
)
model_obj = await self._get_model(params.model, ModelType.embedding)
# Update model to use resolved identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
return await provider.openai_embeddings(params)
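
The completion, chat-completion, and embeddings routes above now take a single request object instead of a long keyword list. A hedged sketch of the calling convention (the model id is illustrative, and coercion of plain message dicts into `OpenAIMessageParam` via Pydantic validation is assumed):

```python
async def chat_example(router) -> None:
    params = OpenAIChatCompletionRequestWithExtraBody(
        model="example-provider/example-llm",              # illustrative model id
        messages=[{"role": "user", "content": "Hello!"}],  # assumed to coerce to OpenAIMessageParam
        stream=False,
    )
    completion = await router.openai_chat_completion(params)
    print(completion)
```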
async def list_chat_completions(
self,
@ -396,8 +331,10 @@ class InferenceRouter(Inference):
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
@ -456,7 +393,7 @@ class InferenceRouter(Inference):
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
@ -464,7 +401,7 @@ class InferenceRouter(Inference):
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for streaming completion metrics
if self.telemetry:
if self.telemetry_enabled:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
@ -513,7 +450,7 @@ class InferenceRouter(Inference):
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
if self.telemetry:
if self.telemetry_enabled:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
@ -611,7 +548,7 @@ class InferenceRouter(Inference):
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,


@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -20,9 +21,11 @@ class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
safety_config: SafetyConfig | None = None,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
self.safety_config = safety_config
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
@ -60,26 +63,47 @@ class SafetyRouter(Safety):
params=params,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
async def get_shield_id(self, model: str) -> str:
"""Get Shield id from model (provider_resource_id) of shield."""
list_shields_response = await self.routing_table.list_shields()
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
list_shields_response = await self.routing_table.list_shields()
shields = list_shields_response.data
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
selected_shield: Shield | None = None
provider_model: str | None = model
if model:
matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}")
raise ValueError(
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
)
if len(matches) > 1:
raise ValueError(f"Multiple shields associated with provider_resource id {model}")
return matches[0]
raise ValueError(
f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
)
selected_shield = matches[0]
else:
default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
if not default_shield_id:
raise ValueError(
"No moderation model specified and no default_shield_id configured in safety config: select model "
f"from {[s.provider_resource_id or s.identifier for s in shields]}"
)
shield_id = await get_shield_id(self, model)
selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
if selected_shield is None:
raise ValueError(
f"Default moderation model not found. Choose from {[s.provider_resource_id or s.identifier for s in shields]}."
)
provider_model = selected_shield.provider_resource_id
shield_id = selected_shield.identifier
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_moderation(
input=input,
model=model,
model=provider_model,
)
return response
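
A hedged sketch of the default-shield fallback above: when the router is constructed with a `SafetyConfig`, `run_moderation` can be called without a model and the configured shield is used. The shield id and router variable are illustrative.

```python
async def moderation_example(safety_router) -> None:
    # assumes the router was built with SafetyConfig(default_shield_id="llama-guard")
    # and that a shield with that identifier is registered
    result = await safety_router.run_moderation(input="Is this message safe?")
    print(result)
```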


@ -37,24 +37,24 @@ class ToolRuntimeRouter(ToolRuntime):
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
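
A brief hedged sketch of the renamed parameter above; the vector store id is illustrative, and passing a plain string as `InterleavedContent` is assumed to be accepted.

```python
async def rag_query_example(rag_tool) -> None:
    result = await rag_tool.query(
        content="What does the retention policy say?",  # assumed valid InterleavedContent
        vector_store_ids=["vs_123"],
    )
    print(result)
```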
def __init__(
self,


@ -6,12 +6,16 @@
import asyncio
import uuid
from typing import Any
from typing import Annotated, Any
from fastapi import Body
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
@ -27,6 +31,7 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
@ -39,9 +44,11 @@ class VectorIORouter(VectorIO):
def __init__(
self,
routing_table: RoutingTable,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
self.vector_stores_config = vector_stores_config
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
@ -51,49 +58,18 @@ class VectorIORouter(VectorIO):
logger.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
"""Get the embedding dimension for a specific embedding model."""
all_models = await self.routing_table.get_all_with_type("model")
# Filter for embedding models
embedding_models = [
model
for model in all_models
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
for model in all_models:
if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
dimension = model.metadata.get("embedding_dimension")
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
return None
raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
return int(dimension)
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
vector_db_name,
provider_vector_db_id,
)
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
async def insert_chunks(
self,
@ -101,8 +77,10 @@ class VectorIORouter(VectorIO):
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
doc_ids = [chunk.document_id for chunk in chunks[:3]]
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -120,46 +98,76 @@ class VectorIORouter(VectorIO):
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# Extract llama-stack-specific parameters from extra_body
extra = params.model_extra or {}
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
# Use default embedding model if not specified
if (
embedding_model is None
and self.vector_stores_config
and self.vector_stores_config.default_embedding_model is not None
):
# Construct the full model ID with provider prefix
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
model_id = self.vector_stores_config.default_embedding_model.model_id
embedding_model = f"{embedding_provider_id}/{model_id}"
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id=vector_db_id,
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
# Auto-select provider if not specified
if provider_id is None:
num_providers = len(self.routing_table.impls_by_provider_id)
if num_providers == 0:
raise ValueError("No vector_io providers available")
if num_providers > 1:
available_providers = list(self.routing_table.impls_by_provider_id.keys())
# Use default configured provider
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
default_provider = self.vector_stores_config.default_provider_id
if default_provider in available_providers:
provider_id = default_provider
logger.debug(f"Using configured default vector store provider: {provider_id}")
else:
raise ValueError(
f"Configured default vector store provider '{default_provider}' not found. "
f"Available providers: {available_providers}"
)
else:
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_store_id = f"vs_{uuid.uuid4()}"
registered_vector_store = await self.routing_table.register_vector_store(
vector_store_id=vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=registered_vector_db.provider_id,
provider_vector_db_id=registered_vector_db.provider_resource_id,
provider_vector_store_id=vector_store_id,
vector_store_name=params.name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
# Update model_extra with registered values so provider uses the already-registered vector_store
if params.model_extra is None:
params.model_extra = {}
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_store.provider_id
if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None:
params.model_extra["embedding_dimension"] = embedding_dimension
return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores(
self,
@ -171,15 +179,15 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector store to get the list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
vector_stores = await self.routing_table.get_all_with_type("vector_store")
all_stores = []
for vector_db in vector_dbs:
for vector_store in vector_stores:
try:
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
continue
# Sort by created_at
@ -219,7 +227,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
@ -229,7 +238,8 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -254,7 +264,8 @@ class VectorIORouter(VectorIO):
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -272,7 +283,8 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -289,7 +301,8 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -304,7 +317,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -315,7 +329,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -327,7 +342,8 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -339,7 +355,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -370,17 +387,13 @@ class VectorIORouter(VectorIO):
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
return await self.routing_table.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
logger.debug(
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(vector_store_id, params)
async def openai_retrieve_vector_store_file_batch(
self,
@ -388,7 +401,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
@ -404,7 +418,8 @@ class VectorIORouter(VectorIO):
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
@ -420,7 +435,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_cancel_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
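
The provider-selection and default-embedding logic in `openai_create_vector_store` above is driven by extra fields on the request model. A hedged sketch of a caller supplying those hints; the ids are illustrative, and it is assumed the request model accepts extra fields (as its `model_extra` usage suggests).

```python
async def create_store_example(vector_io_router) -> None:
    request = OpenAICreateVectorStoreRequestWithExtraBody(
        name="docs",
        embedding_model="example-provider/all-MiniLM-L6-v2",  # llama-stack specific extra field
        provider_id="faiss",                                   # llama-stack specific extra field
    )
    store = await vector_io_router.openai_create_vector_store(request)
    print(store)
```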


@ -9,7 +9,6 @@ from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithOwner,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
@ -41,7 +41,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
return await p.register_vector_store(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
@ -57,7 +57,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
return await p.unregister_vector_store(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.safety:
@ -108,13 +108,13 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
p.vector_store_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
@ -134,15 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
from .vector_stores import VectorStoresRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, VectorStoresRoutingTable):
return ("VectorIO", "vector_store")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
@ -248,25 +248,7 @@ class CommonRoutingTableImpl(RoutingTable):
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await routing_table.get_object_by_identifier("model", model_id)
if model is not None:
return model
logger.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
if not model:
raise ModelNotFoundError(model_id)
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]
return model

View file

@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
@ -104,15 +104,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=identifier,
provider_resource_id=provider_model_id,
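With the conditional removed, the registry identifier is built the same way for every model; a one-line illustration with invented provider and model ids:

# Invented ids: the registry identifier is always "<provider_id>/<provider_model_id>".
provider_id = "ollama"
provider_model_id = "llama3.2:3b"
identifier = f"{provider_id}/{provider_model_id}"
assert identifier == "ollama/llama3.2:3b"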

View file

@ -6,12 +6,11 @@
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
@ -24,7 +23,7 @@ from llama_stack.apis.vector_io.vector_io import (
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
@ -33,25 +32,24 @@ from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
class VectorStoresRoutingTable(CommonRoutingTableImpl):
"""Internal routing table for vector_store operations.
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
return vector_db
Does not inherit from VectorStores to avoid exposing public API endpoints.
Only provides internal routing functionality for VectorIORouter.
"""
async def register_vector_db(
# Internal methods only - no public API exposure
async def register_vector_store(
self,
vector_db_id: str,
vector_store_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
provider_vector_store_id: str | None = None,
vector_store_name: str | None = None,
) -> Any:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -66,49 +64,24 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
vector_store = VectorStoreWithOwner(
identifier=vector_store_id,
type=ResourceType.vector_store.value,
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
provider_resource_id=provider_vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
vector_store_name=vector_store_name,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
await self.unregister_object(existing_vector_db)
await self.register_object(vector_store)
return vector_store
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
@ -119,7 +92,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
@ -132,12 +105,22 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
await self.unregister_vector_store(vector_store_id)
return result
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Remove the vector store from the routing table registry."""
try:
vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
if vector_store_obj:
await self.unregister_object(vector_store_obj)
except Exception as e:
# Log the error but don't fail the operation
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
async def openai_search_vector_store(
self,
vector_store_id: str,
@ -148,7 +131,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
@ -167,7 +150,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
@ -185,7 +168,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
@ -201,7 +184,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
@ -213,7 +196,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
@ -226,7 +209,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
@ -239,7 +222,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
@ -253,7 +236,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
@ -267,7 +250,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
@ -284,7 +267,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
@ -301,7 +284,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,

View file

@ -27,6 +27,11 @@ class AuthenticationMiddleware:
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Unauthenticated Access:
Endpoints can opt out of authentication by setting require_authentication=False
in their @webmethod decorator. This is typically used for operational endpoints
like /health and /version to support monitoring, load balancers, and observability tools.
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
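As a side note for reviewers, the opt-out described above looks roughly like the sketch below. The decorator name and its require_authentication / required_scope parameters come from this diff; the import path, routes, scope string, and surrounding protocol are assumptions for illustration only.

# Hypothetical endpoint declarations; import path and protocol shape are assumed.
from typing import Protocol

from llama_stack.schema_utils import webmethod  # assumed location of the decorator


class Inspect(Protocol):
    @webmethod(route="/v1/health", method="GET", require_authentication=False)
    async def health(self) -> dict:
        """Liveness endpoint reachable by load balancers without a bearer token."""
        ...

    @webmethod(route="/v1/inspect/routes", method="GET", required_scope="inspect:read")
    async def list_routes(self) -> dict:
        """Authenticated endpoint that additionally requires a specific scope."""
        ...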
@ -88,7 +93,26 @@ class AuthenticationMiddleware:
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
# Find the route and check if authentication is required
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
webmethod = None
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, fall through and run auth anyway
pass
# If webmethod explicitly sets require_authentication=False, allow without auth
if webmethod and webmethod.require_authentication is False:
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
return await self.app(scope, receive, send)
# Handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@ -127,19 +151,7 @@ class AuthenticationMiddleware:
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
if webmethod and webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(

View file

@ -5,13 +5,11 @@
# the root directory of this source tree.
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from urllib.parse import parse_qs, urljoin, urlparse
import httpx
from jose import jwt
import jwt
from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
@ -74,13 +72,30 @@ class AuthProvider(ABC):
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items():
if claim_key not in claims:
# First try dot notation for nested traversal (e.g., "resource_access.llamastack.roles")
# Then fall back to literal key with dots (e.g., "my.dotted.key")
claim: object = claims
keys = claim_key.split(".")
for key in keys:
if isinstance(claim, dict) and key in claim:
claim = claim[key]
else:
claim = None
break
if claim is None and claim_key in claims:
# Fall back to checking if claim_key exists as a literal key
claim = claims[claim_key]
if claim is None:
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
elif isinstance(claim, str):
values = claim.split()
else:
continue
if attribute_key in attributes:
attributes[attribute_key].extend(values)
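To make the new traversal concrete, here is a small data-only sketch; the claims payload and mapping values are invented, and only the resolution rules mirror the function above (nested dot-path first, literal dotted key as a fallback, strings split on whitespace).

# Invented OIDC claims and mapping illustrating get_attributes_from_claims.
claims = {
    "sub": "user-123",
    "resource_access": {"llamastack": {"roles": ["admin", "reader"]}},
    "my.dotted.key": "team-a team-b",
}
mapping = {
    "sub": "username",                            # plain claim
    "resource_access.llamastack.roles": "roles",  # resolved by nested traversal
    "my.dotted.key": "teams",                     # falls back to the literal dotted key
}
# Expected result:
# {"username": ["user-123"], "roles": ["admin", "reader"], "teams": ["team-a", "team-b"]}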
@ -98,9 +113,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
self._jwks_client: jwt.PyJWKClient | None = None
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
@ -109,23 +122,60 @@ class OAuth2TokenAuthProvider(AuthProvider):
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
def _get_jwks_client(self) -> jwt.PyJWKClient:
if self._jwks_client is None:
ssl_context = None
if not self.config.verify_tls:
# Disable SSL verification if verify_tls is False
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif self.config.tls_cafile:
# Use custom CA file if provided
ssl_context = ssl.create_default_context(
cafile=self.config.tls_cafile.as_posix(),
)
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
# Prepare headers for the JWKS request - Kubernetes requires authentication to its
# JWKS endpoint, so the token from the config is sent as a bearer credential
headers = {}
if self.config.jwks and self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
self._jwks_client = jwt.PyJWKClient(
self.config.jwks.uri if self.config.jwks else None,
cache_keys=True,
max_cached_keys=10,
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
headers=headers,
ssl_context=ssl_context,
)
return self._jwks_client
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
jwks_client: jwt.PyJWKClient = self._get_jwks_client()
signing_key = jwks_client.get_signing_key_from_jwt(token)
algorithm = jwt.get_unverified_header(token)["alg"]
claims = jwt.decode(
token,
key_data,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
# Decode and verify the JWT
claims = jwt.decode(
token,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
except Exception as exc:
raise ValueError("Invalid JWT token") from exc
@ -201,37 +251,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""

View file

@ -10,10 +10,10 @@ from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
logger = get_logger(name=__name__, category="core::server")
@ -33,7 +33,7 @@ class QuotaMiddleware:
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
kv_config: KVStoreReference,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
@ -45,15 +45,15 @@ class QuotaMiddleware:
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
backend_config = _KVSTORE_BACKENDS.get(self.kv_config.backend)
if backend_config and backend_config.type == StorageBackendType.KV_SQLITE:
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):

View file

@ -36,7 +36,6 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
@ -53,19 +52,13 @@ from llama_stack.core.stack import (
cast_image_name_to_string,
replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
setup_logger,
)
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
@ -138,6 +131,13 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
# Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
# These include AuthenticationError (401), PermissionDeniedError (403), etc.
# This preserves the actual HTTP status code from the provider
status_code = exc.status_code
detail = str(exc)
return HTTPException(status_code=status_code, detail=detail)
else:
return HTTPException(
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
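The new elif branch above relies only on duck typing, so any upstream SDK error carrying an integer status_code keeps its original HTTP code instead of collapsing to 500. A self-contained sketch with an invented exception class:

# Invented stand-in for a provider SDK error such as an APIStatusError subclass.
class UpstreamAuthError(Exception):
    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        self.status_code = status_code


exc = UpstreamAuthError("invalid API key", status_code=401)
# The branch only checks for an integer `status_code` attribute, so this error
# would be surfaced to the client as HTTP 401 rather than a generic 500.
assert hasattr(exc, "status_code") and isinstance(exc.status_code, int)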
@ -167,7 +167,9 @@ class StackApp(FastAPI):
@asynccontextmanager
async def lifespan(app: StackApp):
logger.info("Starting up")
server_version = parse_version("llama-stack")
logger.info(f"Starting up Llama Stack server (version: {server_version})")
assert app.stack is not None
app.stack.create_registry_refresh_task()
yield
@ -177,7 +179,17 @@ async def lifespan(app: StackApp):
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
# If there's a stream parameter at top level, use it
if "stream" in kwargs:
return kwargs["stream"]
# If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
if "params" in kwargs:
params = kwargs["params"]
if hasattr(params, "stream"):
return params.stream
return False
async def maybe_await(value):
@ -232,15 +244,31 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
await log_request_pre_validation(request)
test_context_token = None
test_context_var = None
reset_test_context_fn = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
test_context_var = TEST_CONTEXT
reset_test_context_fn = reset_test_context
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
if test_context_var is not None:
context_vars.append(test_context_var)
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
@ -258,6 +286,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None and reset_test_context_fn is not None:
reset_test_context_fn(test_context_token)
sig = inspect.signature(func)
@ -338,6 +369,9 @@ def create_app() -> StackApp:
Returns:
Configured StackApp instance.
"""
# Initialize logging from environment variables first
setup_logging()
config_file = os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
@ -409,10 +443,8 @@ def create_app() -> StackApp:
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
if config.telemetry.enabled:
setup_logger(Telemetry())
# Load external APIs if configured
external_apis = load_external_apis(config)
@ -470,7 +502,8 @@ def create_app() -> StackApp:
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
if config.telemetry.enabled:
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app

View file

@ -7,8 +7,8 @@ from aiohttp import hdrs
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.core.telemetry.tracing import end_trace, start_trace
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
logger = get_logger(name=__name__, category="core::server")

View file

@ -33,16 +33,25 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageBackendConfig,
StorageConfig,
)
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -53,7 +62,6 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
Agents,
Safety,
@ -83,7 +91,6 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
@ -103,7 +110,7 @@ TEST_RECORDING_CONTEXT = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
objects = getattr(run_config.registered_resources, rsrc)
if api not in impls:
continue
@ -132,6 +139,66 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
)
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
"""Validate vector stores configuration."""
if vector_stores_config is None:
return
default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
if safety_config is None or safety_config.default_shield_id is None:
return
if Api.shields not in impls:
raise ValueError("Safety configuration requires the shields API to be enabled")
if Api.safety not in impls:
raise ValueError("Safety configuration requires the safety API to be enabled")
shields_impl = impls[Api.shields]
response = await shields_impl.list_shields()
shields_by_id = {shield.identifier: shield for shield in response.data}
default_shield_id = safety_config.default_shield_id
# don't validate if there are no shields registered
if shields_by_id and default_shield_id not in shields_by_id:
available = sorted(shields_by_id)
raise ValueError(
f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
f" Available shields: {available}"
)
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
@ -306,6 +373,25 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.conversations] = conversations_impl
def _initialize_storage(run_config: StackRunConfig):
kv_backends: dict[str, StorageBackendConfig] = {}
sql_backends: dict[str, StorageBackendConfig] = {}
for backend_name, backend_config in run_config.storage.backends.items():
type = backend_config.type.value
if type.startswith("kv_"):
kv_backends[backend_name] = backend_config
elif type.startswith("sql_"):
sql_backends[backend_name] = backend_config
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
@ -316,22 +402,31 @@ class Stack:
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
_initialize_storage(self.run_config)
stores = self.run_config.storage.stores
if not stores.metadata:
raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, self.run_config)
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
dist_registry,
policy,
internal_impls,
)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
@ -339,8 +434,9 @@ class Stack:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
await validate_safety_config(self.run_config.safety, impls)
self.impls = impls
def create_registry_refresh_task(self):
@ -381,7 +477,7 @@ class Stack:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
@ -460,5 +556,17 @@ def run_config_from_adhoc_config_spec(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
storage=StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=f"{distro_dir}/kvstore.db"),
"sql_default": SqliteSqlStoreConfig(db_path=f"{distro_dir}/sql_store.db"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
),
),
)
return config

View file

@ -0,0 +1,287 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from abc import abstractmethod
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
class StorageBackendType(StrEnum):
KV_REDIS = "kv_redis"
KV_SQLITE = "kv_sqlite"
KV_POSTGRES = "kv_postgres"
KV_MONGODB = "kv_mongodb"
SQL_SQLITE = "sql_sqlite"
SQL_POSTGRES = "sql_postgres"
class CommonConfig(BaseModel):
namespace: str | None = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
class RedisKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_REDIS] = StorageBackendType.KV_REDIS
host: str = "localhost"
port: int = 6379
@property
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@classmethod
def pip_packages(cls) -> list[str]:
return ["redis"]
@classmethod
def sample_run_config(cls):
return {
"type": StorageBackendType.KV_REDIS.value,
"host": "${env.REDIS_HOST:=localhost}",
"port": "${env.REDIS_PORT:=6379}",
}
class SqliteKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_SQLITE] = StorageBackendType.KV_SQLITE
db_path: str = Field(
description="File path for the sqlite database",
)
@classmethod
def pip_packages(cls) -> list[str]:
return ["aiosqlite"]
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
"type": StorageBackendType.KV_SQLITE.value,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
class PostgresKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_POSTGRES] = StorageBackendType.KV_POSTGRES
host: str = "localhost"
port: int | str = 5432
db: str = "llamastack"
user: str
password: str | None = None
ssl_mode: str | None = None
ca_cert_path: str | None = None
table_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
return {
"type": StorageBackendType.KV_POSTGRES.value,
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
}
@field_validator("table_name")
@classmethod
def validate_table_name(cls, v: str) -> str:
# PostgreSQL identifiers rules:
# - Must start with a letter or underscore
# - Can contain letters, numbers, and underscores
# - Maximum length is 63 bytes
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
if not re.match(pattern, v):
raise ValueError(
"Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
)
if len(v) > 63:
raise ValueError("Table name must be less than 63 characters")
return v
@classmethod
def pip_packages(cls) -> list[str]:
return ["psycopg2-binary"]
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_MONGODB] = StorageBackendType.KV_MONGODB
host: str = "localhost"
port: int = 27017
db: str = "llamastack"
user: str | None = None
password: str | None = None
collection_name: str = "llamastack_kvstore"
@classmethod
def pip_packages(cls) -> list[str]:
return ["pymongo"]
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {
"type": StorageBackendType.KV_MONGODB.value,
"host": "${env.MONGODB_HOST:=localhost}",
"port": "${env.MONGODB_PORT:=5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
}
class SqlAlchemySqlStoreConfig(BaseModel):
@property
@abstractmethod
def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs
@classmethod
def pip_packages(cls) -> list[str]:
return ["sqlalchemy[asyncio]"]
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[StorageBackendType.SQL_SQLITE] = StorageBackendType.SQL_SQLITE
db_path: str = Field(
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
)
@property
def engine_str(self) -> str:
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return {
"type": StorageBackendType.SQL_SQLITE.value,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["aiosqlite"]
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[StorageBackendType.SQL_POSTGRES] = StorageBackendType.SQL_POSTGRES
host: str = "localhost"
port: int | str = 5432
db: str = "llamastack"
user: str
password: str | None = None
@property
def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["asyncpg"]
@classmethod
def sample_run_config(cls, **kwargs):
return {
"type": StorageBackendType.SQL_POSTGRES.value,
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
}
# reference = (backend_name, table_name)
class SqlStoreReference(BaseModel):
"""A reference to a 'SQL-like' persistent store. A table name must be provided."""
table_name: str = Field(
description="Name of the table to use for the SqlStore",
)
backend: str = Field(
description="Name of backend from storage.backends",
)
# reference = (backend_name, namespace)
class KVStoreReference(BaseModel):
"""A reference to a 'key-value' persistent store. A namespace must be provided."""
namespace: str = Field(
description="Key prefix for KVStore backends",
)
backend: str = Field(
description="Name of backend from storage.backends",
)
StorageBackendConfig = Annotated[
RedisKVStoreConfig
| SqliteKVStoreConfig
| PostgresKVStoreConfig
| MongoDBKVStoreConfig
| SqliteSqlStoreConfig
| PostgresSqlStoreConfig,
Field(discriminator="type"),
]
class InferenceStoreReference(SqlStoreReference):
"""Inference store configuration with queue tuning."""
max_write_queue_size: int = Field(
default=10000,
description="Max queued writes for inference store",
)
num_writers: int = Field(
default=4,
description="Number of concurrent background writers",
)
class ResponsesStoreReference(InferenceStoreReference):
"""Responses store configuration with queue tuning."""
class ServerStoresConfig(BaseModel):
metadata: KVStoreReference | None = Field(
default=None,
description="Metadata store configuration (uses KV backend)",
)
inference: InferenceStoreReference | None = Field(
default=None,
description="Inference store configuration (uses SQL backend)",
)
conversations: SqlStoreReference | None = Field(
default=None,
description="Conversations store configuration (uses SQL backend)",
)
responses: ResponsesStoreReference | None = Field(
default=None,
description="Responses store configuration (uses SQL backend)",
)
prompts: KVStoreReference | None = Field(
default=None,
description="Prompts store configuration (uses KV backend)",
)
class StorageConfig(BaseModel):
backends: dict[str, StorageBackendConfig] = Field(
description="Named backend configurations (e.g., 'default', 'cache')",
)
stores: ServerStoresConfig = Field(
default_factory=lambda: ServerStoresConfig(),
description="Named references to storage backends used by the stack core",
)

View file

@ -11,10 +11,9 @@ from typing import Protocol
import pydantic
from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core::registry")
@ -96,11 +95,10 @@ class DiskDistributionRegistry(DistributionRegistry):
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
# don't overwrite: reject registration if a different object with the same type and identifier already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
if existing_obj and existing_obj != obj:
raise ValueError(
f"Provider '{obj.provider_id}' is already registered."
f"Unregister the existing provider first before registering it again."
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
"Unregister it first if you want to replace it."
)
await self.kvstore.set(
@ -192,16 +190,10 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
metadata_store: KVStoreConfig | None,
image_name: str,
metadata_store: KVStoreReference, image_name: str
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_kvstore = await kvstore_impl(metadata_store)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()
return dist_registry, dist_kvstore

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import Telemetry
from .trace_protocol import serialize_value, trace_protocol
from .tracing import (
CURRENT_TRACE_CONTEXT,
ROOT_SPAN_MARKERS,
end_trace,
enqueue_event,
get_current_span,
setup_logger,
span,
start_trace,
)
__all__ = [
"Telemetry",
"trace_protocol",
"serialize_value",
"CURRENT_TRACE_CONTEXT",
"ROOT_SPAN_MARKERS",
"end_trace",
"enqueue_event",
"get_current_span",
"setup_logger",
"span",
"start_trace",
]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
import os
import threading
from typing import Any
@ -13,43 +13,24 @@ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExp
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
MetricLabelMatcher,
MetricQueryType,
QueryCondition,
QueryMetricsResponse,
QuerySpanTreeResponse,
QueryTracesResponse,
Span,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
Telemetry,
Trace,
UnstructuredLogEvent,
)
from llama_stack.core.datatypes import Api
from llama_stack.apis.telemetry import (
Telemetry as TelemetryBase,
)
from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
from llama_stack.log import get_logger
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
@ -68,66 +49,48 @@ def is_tracing_enabled(tracer):
return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
class Telemetry(TelemetryBase):
def __init__(self) -> None:
self.meter = None
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
global _TRACER_PROVIDER
# Initialize the correct span processor based on the provider state.
# This is needed since once the span processor is set, it cannot be unset.
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
# Since the library client can be recreated multiple times in a notebook,
# the kernel will hold on to the span processor and cause duplicate spans to be written.
if _TRACER_PROVIDER is None:
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
if _TRACER_PROVIDER is None:
provider = TracerProvider()
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
# Use single OTLP endpoint for all telemetry signals
if TelemetrySink.OTEL_TRACE in self.config.sinks or TelemetrySink.OTEL_METRIC in self.config.sinks:
if self.config.otel_exporter_otlp_endpoint is None:
raise ValueError(
"otel_exporter_otlp_endpoint is required when OTEL_TRACE or OTEL_METRIC is enabled"
)
# Use single OTLP endpoint for all telemetry signals
# Let OpenTelemetry SDK handle endpoint construction automatically
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
if TelemetrySink.OTEL_TRACE in self.config.sinks:
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
self.is_otel_endpoint_set = True
else:
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
self.is_otel_endpoint_set = False
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
if self.is_otel_endpoint_set:
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
@ -139,47 +102,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
else:
raise ValueError(f"Unknown event type: {event}")
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
"""Query metrics from the telemetry store.
Args:
metric_name: The name of the metric to query (e.g., "prompt_tokens")
start_time: Start time as Unix timestamp
end_time: End time as Unix timestamp (defaults to now if None)
granularity: Time granularity for aggregation
query_type: Type of query (RANGE or INSTANT)
label_matchers: Label filters to apply
Returns:
QueryMetricsResponse with metric time series data
"""
# Convert timestamps to datetime objects
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
# Use SQLite trace store if available
if hasattr(self, "trace_store") and self.trace_store:
return await self.trace_store.query_metrics(
metric_name=metric_name,
start_time=start_dt,
end_time=end_dt,
granularity=granularity,
query_type=query_type,
label_matchers=label_matchers,
)
else:
raise ValueError(
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
)
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
@ -326,39 +248,3 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse:
return QueryTracesResponse(
data=await self.trace_store.query_traces(
attribute_filters=attribute_filters,
limit=limit,
offset=offset,
order_by=order_by,
)
)
async def get_trace(self, trace_id: str) -> Trace:
return await self.trace_store.get_trace(trace_id)
async def get_span(self, trace_id: str, span_id: str) -> Span:
return await self.trace_store.get_span(trace_id, span_id)
async def get_span_tree(
self,
span_id: str,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse:
return QuerySpanTreeResponse(
data=await self.trace_store.get_span_tree(
span_id=span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
)

View file

@ -9,27 +9,29 @@ import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any
from typing import Any, cast
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]
def serialize_value(value: Any) -> Primitive:
def serialize_value(value: Any) -> str:
return str(_prepare_for_json(value))
def _prepare_for_json(value: Any) -> str:
def _prepare_for_json(value: Any) -> JSONValue:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
elif isinstance(value, str | int | float | bool):
return value
elif hasattr(value, "_name_"):
return value._name_
return cast(str, value._name_)
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
return cast(JSONValue, json.loads(value.model_dump_json()))
elif isinstance(value, list | tuple | set):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
@ -37,53 +39,53 @@ def _prepare_for_json(value: Any) -> str:
else:
try:
json.dumps(value)
return value
return cast(JSONValue, value)
except Exception:
return str(value)
def trace_protocol[T](cls: type[T]) -> type[T]:
def trace_protocol[T: type[Any]](cls: T) -> T:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable) -> Callable:
def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
combined_args: dict[str, str] = {}
for i, arg in enumerate(args):
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes = {
span_attributes: dict[str, Primitive] = {
"__autotraced__": True,
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": str(combined_args),
"__args__": json.dumps(combined_args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator:
from llama_stack.providers.utils.telemetry import tracing
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
count = 0
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1
@ -92,7 +94,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
@ -107,7 +109,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
@ -127,16 +129,17 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
else:
return sync_wrapper
original_init_subclass = getattr(cls, "__init_subclass__", None)
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
def __init_subclass__(cls_child, **kwargs): # noqa: N807
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
if original_init_subclass:
original_init_subclass(**kwargs)
cast(Callable[..., None], original_init_subclass)(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls.__init_subclass__ = classmethod(__init_subclass__)
cls_any = cast(Any, cls)
cls_any.__init_subclass__ = classmethod(__init_subclass__)
return cls
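
For orientation, a hedged usage sketch of the decorator after this refactor: only subclasses of the decorated base get their public methods wrapped (via the `__init_subclass__` hook above), with arguments serialized into the span's `__args__` attribute. The class and method names below are illustrative, not part of the codebase.

```python
# Illustrative sketch only; the class and method names are hypothetical.
from llama_stack.core.telemetry.trace_protocol import trace_protocol


@trace_protocol
class InferenceProtocol:
    """Base protocol; inheriting classes get their public methods auto-traced."""


class EchoInference(InferenceProtocol):
    async def chat(self, prompt: str) -> str:
        # Wrapped at subclass-creation time: a span named "EchoInference.chat"
        # is opened with the serialized arguments attached as "__args__".
        return f"echo: {prompt}"
```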

View file

@ -15,7 +15,7 @@ import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any
from typing import Any, Self
from llama_stack.apis.telemetry import (
Event,
@ -28,8 +28,8 @@ from llama_stack.apis.telemetry import (
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.core.telemetry.trace_protocol import serialize_value
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")
@ -89,9 +89,6 @@ def generate_trace_id() -> str:
return trace_id_to_str(trace_id)
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0
@ -104,7 +101,7 @@ class BackgroundLogger:
self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0
def log_event(self, event):
def log_event(self, event: Event) -> None:
try:
self.log_queue.put_nowait(event)
except queue.Full:
@ -137,10 +134,13 @@ class BackgroundLogger:
finally:
self.log_queue.task_done()
def __del__(self):
def __del__(self) -> None:
self.log_queue.join()
BACKGROUND_LOGGER: BackgroundLogger | None = None
def enqueue_event(event: Event) -> None:
"""Enqueue a telemetry event to the background logger if available.
@ -155,13 +155,12 @@ def enqueue_event(event: Event) -> None:
class TraceContext:
spans: list[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
self.spans: list[Span] = []
def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
@ -188,7 +187,7 @@ class TraceContext:
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK):
def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
span = self.spans.pop()
if span is not None:
self.logger.log_event(
@ -203,10 +202,15 @@ class TraceContext:
)
)
def get_current_span(self):
def get_current_span(self) -> Span | None:
return self.spans[-1] if self.spans else None
CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
"trace_context", default=None
)
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
@ -217,12 +221,12 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
return None
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
@ -269,7 +273,7 @@ def severity(levelname: str) -> LogSeverity:
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
def emit(self, record: logging.LogRecord) -> None:
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
@ -293,17 +297,17 @@ class TelemetryHandler(logging.Handler):
)
)
def close(self):
def close(self) -> None:
pass
class SpanContextManager:
def __init__(self, name: str, attributes: dict[str, Any] = None):
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
self.name = name
self.attributes = attributes
self.span = None
self.span: Span | None = None
def __enter__(self):
def __enter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
@ -313,7 +317,7 @@ class SpanContextManager:
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
@ -322,13 +326,13 @@ class SpanContextManager:
context.pop_span()
def set_attribute(self, key: str, value: Any):
def set_attribute(self, key: str, value: Any) -> None:
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = serialize_value(value)
async def __aenter__(self):
async def __aenter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
@ -338,7 +342,7 @@ class SpanContextManager:
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
@ -347,19 +351,19 @@ class SpanContextManager:
context.pop_span()
def __call__(self, func: Callable):
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def sync_wrapper(*args, **kwargs):
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async with self:
return await func(*args, **kwargs)
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
@ -368,7 +372,7 @@ class SpanContextManager:
return wrapper
def span(name: str, attributes: dict[str, Any] = None):
def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
return SpanContextManager(name, attributes)
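
As a quick reference for callers migrating to the new import path, a minimal sketch of the context-manager and decorator forms. The attribute names and values are illustrative, and it assumes `setup_logger()` has been called with a Telemetry implementation; otherwise `start_trace()` is a no-op returning None.

```python
from llama_stack.core.telemetry import tracing


async def handle_request(user_id: str) -> None:
    # Opens a new trace; returns None when no Telemetry backend is configured.
    await tracing.start_trace("handle_request", {"user_id": user_id})

    # Context-manager form: pushes a child span on entry, pops it on exit.
    with tracing.span("load_profile", {"user_id": user_id}) as s:
        s.set_attribute("cache_hit", False)

    # Decorator form: SpanContextManager.__call__ wraps sync and async callables.
    @tracing.span("compute_answer")
    async def compute_answer() -> int:
        return 42

    await compute_answer()
```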

View file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
def is_debug_mode() -> bool:
"""Check if test recording debug mode is enabled via LLAMA_STACK_TEST_DEBUG env var."""
return os.environ.get("LLAMA_STACK_TEST_DEBUG", "").lower() in ("1", "true", "yes")
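
A short sketch of the intended set/reset pairing for these helpers. The module path in the import is a placeholder, since this new file's location isn't shown in the diff, and the test id is an arbitrary example.

```python
# Placeholder import path -- adjust to wherever this new module actually lives.
from llama_stack.core.testing_context import (
    get_test_context,
    reset_test_context,
    set_test_context,
    sync_test_context_from_provider_data,
)

token = set_test_context("inference::test_chat_completion")  # arbitrary example id
try:
    assert get_test_context() == "inference::test_chat_completion"
    # In server mode (LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server), provider data
    # carrying "__test_id" can drive the context instead:
    sync_test_context_from_provider_data()
finally:
    reset_test_context(token)
```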

View file

@ -9,7 +9,7 @@
1. Start up the Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.html).
```
llama stack build --distro together --image-type venv
llama stack list-deps together | xargs -L1 uv pip install
llama stack run together
```

View file

@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
"Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
@ -37,8 +35,6 @@ def resources_page():
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")

View file

@ -1,301 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
vector_io_provider = None
for x in providers:
if x.api == "vector_io":
vector_io_provider = x.provider_id
llama_stack_api.client.vector_dbs.register(
vector_db_id=vector_db_name, # Use the user-provided name
embedding_dimension=384,
embedding_model="all-MiniLM-L6-v2",
provider_id=vector_io_provider,
)
# insert documents using the custom vector db name
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
],
)
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -32,7 +32,7 @@ def tool_chat_page():
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = []
selected_vector_stores = []
def reset_agent():
st.session_state.clear()
@ -55,13 +55,13 @@ def tool_chat_page():
)
if "builtin::rag" in toolgroup_selection:
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
if not vector_dbs:
vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_stores:
st.info("No vector databases available for selection.")
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
vector_stores = [vector_store.identifier for vector_store in vector_stores]
selected_vector_stores = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
options=vector_stores,
on_change=reset_agent,
)
@ -119,7 +119,7 @@ def tool_chat_page():
tool_dict = dict(
name="builtin::rag",
args={
"vector_db_ids": list(selected_vector_dbs),
"vector_store_ids": list(selected_vector_stores),
},
)
toolgroup_selection[i] = tool_dict

View file

@ -42,25 +42,25 @@ def resolve_config_or_distro(
# Strategy 1: Try as file path first
config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
logger.debug(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists():
logger.info(f"Using distribution: {distro_config}")
logger.debug(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error

View file

@ -25,6 +25,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:
@ -32,8 +34,6 @@ distribution_spec:
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::torchtune-cpu
eval:

View file

@ -10,7 +10,6 @@ apis:
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -94,30 +93,30 @@ providers:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
persistence:
namespace: vector_io::faiss
backend: kv_default
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
persistence:
namespace: vector_io::sqlite_vec
backend: kv_default
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
persistence:
namespace: vector_io::milvus
backend: kv_default
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
@ -126,17 +125,32 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
table_name: files_metadata
backend: sql_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -148,20 +162,15 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
post_training:
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
@ -172,21 +181,21 @@ providers:
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -216,30 +225,57 @@ providers:
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
namespace: batches
backend: kv_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
safety:
default_shield_id: llama-guard

View file

@ -14,8 +14,6 @@ distribution_spec:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -32,7 +32,6 @@ def get_distribution_template() -> DistributionTemplate:
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
@ -87,11 +86,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="tgi1",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
model_id="nomic-embed-text-v1.5",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
"embedding_dimension": 768,
},
)
default_tool_groups = [

View file

@ -157,7 +157,7 @@ docker run \
Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
```bash
llama stack build --distro {{ name }} --image-type conda
llama stack list-deps {{ name }} | xargs -L1 pip install
INFERENCE_MODEL=$INFERENCE_MODEL \
DEH_URL=$DEH_URL \
CHROMA_URL=$CHROMA_URL \

View file

@ -7,7 +7,6 @@ apis:
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -27,9 +26,9 @@ providers:
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -39,40 +38,35 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -95,36 +89,56 @@ providers:
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: tgi1
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: tgi1
model_type: llm
- metadata:
embedding_dimension: 768
model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true

View file

@ -7,7 +7,6 @@ apis:
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -23,9 +22,9 @@ providers:
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -35,40 +34,35 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -91,31 +85,51 @@ providers:
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: tgi0
model_type: llm
- metadata:
embedding_dimension: 768
model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true

View file

@ -12,8 +12,6 @@ distribution_spec:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -29,31 +29,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
Please check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](../../references/llama_cli_reference/download_models.md) to download the models using the Hugging Face CLI.
```
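
For reference, a hedged Python-API equivalent of the CLI download (the repo id and target directory below are illustrative; gated Llama checkpoints also require an authenticated Hugging Face token):

```python
import os

from huggingface_hub import snapshot_download  # pip install huggingface_hub

snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    local_dir=os.path.expanduser("~/.llama/checkpoints/Llama-3.1-8B-Instruct"),
)
```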
## Running the Distribution
@ -94,10 +70,10 @@ docker run \
### Via venv
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
Make sure you have the Llama Stack CLI available.
```bash
llama stack build --distro {{ name }} --image-type venv
llama stack list-deps meta-reference-gpu | xargs -L1 uv pip install
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llama stack run distributions/{{ name }}/run.yaml \
--port 8321

View file

@ -34,7 +34,6 @@ def get_distribution_template() -> DistributionTemplate:
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
@ -77,11 +76,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="meta-reference-inference",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
model_id="nomic-embed-text-v1.5",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
"embedding_dimension": 768,
},
)
safety_model = ModelInput(

View file

@ -7,7 +7,6 @@ apis:
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -38,9 +37,9 @@ providers:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/faiss_store.db
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -50,40 +49,35 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -108,36 +102,56 @@ providers:
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: meta-reference-safety
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: meta-reference-safety
model_type: llm
- metadata:
embedding_dimension: 768
model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true

View file

@ -7,7 +7,6 @@ apis:
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -28,9 +27,9 @@ providers:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/faiss_store.db
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -40,40 +39,35 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -98,31 +92,51 @@ providers:
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata:
embedding_dimension: 768
model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true

View file

@ -10,8 +10,6 @@ distribution_spec:
- provider_type: remote::nvidia
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
eval:
- provider_type: remote::nvidia
post_training:

View file

@ -126,11 +126,11 @@ docker run \
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --distro nvidia --image-type venv
llama stack list-deps nvidia | xargs -L1 uv pip install
NVIDIA_API_KEY=$NVIDIA_API_KEY \
INFERENCE_MODEL=$INFERENCE_MODEL \
llama stack run ./run.yaml \

View file

@ -21,7 +21,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
"vector_io": [BuildProvider(provider_type="inline::faiss")],
"safety": [BuildProvider(provider_type="remote::nvidia")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="remote::nvidia")],
"post_training": [BuildProvider(provider_type="remote::nvidia")],
"datasetio": [

View file

@ -9,7 +9,6 @@ apis:
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -29,9 +28,9 @@ providers:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: nvidia
provider_type: remote::nvidia
@ -42,20 +41,15 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: nvidia
provider_type: remote::nvidia
@ -74,8 +68,8 @@ providers:
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
- provider_id: nvidia
provider_type: remote::nvidia
config:
@ -95,32 +89,52 @@ providers:
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
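
The practical effect of the new `storage` section is that the per-provider SQLite files from the old config (`faiss_store.db`, `agents_store.db`, `responses_store.db`, `registry.db`, and so on) collapse into the two shared backends, `kv_default` and `sql_default`. A quick, hedged way to confirm this locally (assumes the nvidia distribution's dependencies are installed and the required env variables are set):

```bash
# Sketch only: run the stack briefly and inspect the storage directory.
export SQLITE_STORE_DIR=~/.llama/distributions/nvidia   # same default as in run.yaml
llama stack run ./run.yaml &                            # start the server in the background
sleep 10                                                # give it time to initialize its stores
ls "$SQLITE_STORE_DIR"                                  # expect kvstore.db and sql_store.db
kill $!                                                 # stop the background server
```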


@ -9,7 +9,6 @@ apis:
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -24,9 +23,9 @@ providers:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: nvidia
provider_type: remote::nvidia
@ -37,20 +36,15 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: nvidia
provider_type: remote::nvidia
@ -84,22 +78,42 @@ providers:
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true


@ -16,8 +16,6 @@ distribution_spec:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:


@ -105,7 +105,6 @@ def get_distribution_template() -> DistributionTemplate:
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),


@ -7,7 +7,6 @@ apis:
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
@ -40,16 +39,16 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
persistence:
namespace: vector_io::sqlite_vec
backend: kv_default
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/chroma_remote_registry.db
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
- provider_id: ${env.ENABLE_PGVECTOR:+pgvector}
provider_type: remote::pgvector
config:
@ -58,9 +57,9 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
persistence:
namespace: vector_io::pgvector
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -70,40 +69,35 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/meta_reference_eval.db
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/huggingface_datasetio.db
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/localfs_datasetio.db
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
@ -128,114 +122,134 @@ providers:
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
models:
- metadata: {}
model_id: gpt-4o
provider_id: openai
provider_model_id: gpt-4o
model_type: llm
- metadata: {}
model_id: claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/IfEval?split=train
metadata: {}
dataset_id: ifeval
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/docvqa?split=val
metadata: {}
dataset_id: docvqa
scoring_fns: []
benchmarks:
- dataset_id: simpleqa
scoring_functions:
- llm-as-judge::405b-simpleqa
metadata: {}
benchmark_id: meta-reference-simpleqa
- dataset_id: mmlu_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
scoring_functions:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
- dataset_id: ifeval
scoring_functions:
- basic::ifeval
metadata: {}
benchmark_id: meta-reference-ifeval
- dataset_id: docvqa
scoring_functions:
- basic::docvqa
metadata: {}
benchmark_id: meta-reference-docvqa
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models:
- metadata: {}
model_id: gpt-4o
provider_id: openai
provider_model_id: gpt-4o
model_type: llm
- metadata: {}
model_id: claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/IfEval?split=train
metadata: {}
dataset_id: ifeval
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/docvqa?split=val
metadata: {}
dataset_id: docvqa
scoring_fns: []
benchmarks:
- dataset_id: simpleqa
scoring_functions:
- llm-as-judge::405b-simpleqa
metadata: {}
benchmark_id: meta-reference-simpleqa
- dataset_id: mmlu_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
scoring_functions:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
- dataset_id: ifeval
scoring_functions:
- basic::ifeval
metadata: {}
benchmark_id: meta-reference-ifeval
- dataset_id: docvqa
scoring_functions:
- basic::docvqa
metadata: {}
benchmark_id: meta-reference-docvqa
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
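
Once the open-benchmark stack is running, the entries under `registered_resources` should be queryable from the server. A hedged sketch of such a check; the `/v1/models` and `/v1/shields` routes are an assumption about the server API, not something shown in this diff:

```bash
# Assumes the server is up on the configured port (8321) and jq is installed.
curl -s http://localhost:8321/v1/models | jq    # should include gpt-4o, claude-3-5-sonnet-latest, ...
curl -s http://localhost:8321/v1/shields | jq   # should include meta-llama/Llama-Guard-3-8B
```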


@ -11,8 +11,6 @@ distribution_spec:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search

Some files were not shown because too many files have changed in this diff.