Merge remote-tracking branch 'origin/main' into stores
Some checks failed
Installer CI / smoke-test-on-dev (push) Failing after 3s
Installer CI / lint (push) Failing after 3s

This commit is contained in:
Ashwin Bharambe 2025-10-13 11:07:11 -07:00
commit b72154ce5e
1161 changed files with 609896 additions and 42960 deletions

View file

@ -797,7 +797,7 @@ class Agents(Protocol):
self,
response_id: str,
) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID.
"""Get a model response.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
@ -812,6 +812,7 @@ class Agents(Protocol):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -826,11 +827,12 @@ class Agents(Protocol):
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
"""Create a model response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) If specified, the new response will be a continuation of the previous response. This can be used to easily fork off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:returns: An OpenAIResponseObject.
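As a usage sketch, the new `conversation` parameter can be exercised with a plain HTTP call to the `/responses` route declared here; the server address, port, and model identifier below are assumptions, and the conversation ID is a placeholder that only illustrates the required `conv_` prefix.

import httpx

BASE_URL = "http://localhost:8321/v1"  # assumed Llama Stack server address

payload = {
    "model": "llama3.2-3b-instruct",          # hypothetical model id
    "input": "Summarize the thread so far.",
    "conversation": "conv_0123456789abcdef",  # must begin with 'conv_'
    "store": True,
    "stream": False,
}

resp = httpx.post(f"{BASE_URL}/responses", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["id"])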
@ -846,7 +848,7 @@ class Agents(Protocol):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all OpenAI responses.
"""List all responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
@ -869,7 +871,7 @@ class Agents(Protocol):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
"""List input items.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
@ -884,7 +886,7 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete an OpenAI response by its ID.
"""Delete a response.
:param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject

View file

@ -346,6 +346,174 @@ class OpenAIResponseText(BaseModel):
format: OpenAIResponseTextFormat | None = None
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
"""Web search tool configuration for OpenAI response inputs.
:param type: Web search tool type variant to use
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
"""
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
"""Function tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "function"
:param name: Name of the function that can be called
:param description: (Optional) Description of what the function does
:param parameters: (Optional) JSON schema defining the function's parameters
:param strict: (Optional) Whether to enforce strict parameter validation
"""
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
"""File search tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "file_search"
:param vector_store_ids: List of vector store identifiers to search within
:param filters: (Optional) Additional filters to apply to the search
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
:param ranking_options: (Optional) Options for ranking and scoring search results
"""
type: Literal["file_search"] = "file_search"
vector_store_ids: list[str]
filters: dict[str, Any] | None = None
max_num_results: int | None = Field(default=10, ge=1, le=50)
ranking_options: FileSearchRankingOptions | None = None
class ApprovalFilter(BaseModel):
"""Filter configuration for MCP tool approval requirements.
:param always: (Optional) List of tool names that always require approval
:param never: (Optional) List of tool names that never require approval
"""
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
"""Filter configuration for restricting which MCP tools can be used.
:param tool_names: (Optional) List of specific tool names that are allowed
"""
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param server_url: URL endpoint of the MCP server
:param headers: (Optional) HTTP headers to include when connecting to the server
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
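A small, self-contained sketch of how this discriminated union resolves a tool payload by its `type` field; the stand-in models below are trimmed copies of the classes above, so the snippet runs without llama_stack installed.

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class WebSearchTool(BaseModel):
    type: Literal["web_search"] = "web_search"
    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")


class MCPTool(BaseModel):
    type: Literal["mcp"] = "mcp"
    server_label: str
    server_url: str
    require_approval: Literal["always", "never"] = "never"


InputTool = Annotated[WebSearchTool | MCPTool, Field(discriminator="type")]
adapter = TypeAdapter(InputTool)

# Pydantic picks the right model purely from the "type" discriminator.
tool = adapter.validate_python(
    {"type": "mcp", "server_label": "docs", "server_url": "http://localhost:3000/sse"}
)
print(type(tool).__name__)  # -> MCPTool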
@json_schema_type
class OpenAIResponseToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response object.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseToolMCP, # The only type that differs from that in the inputs is the MCP tool
Field(discriminator="type"),
]
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI response usage.
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
"""
reasoning_tokens: int | None = None
class OpenAIResponseUsageInputTokensDetails(BaseModel):
"""Token details for input tokens in OpenAI response usage.
:param cached_tokens: Number of tokens retrieved from cache
"""
cached_tokens: int | None = None
@json_schema_type
class OpenAIResponseUsage(BaseModel):
"""Usage information for OpenAI response.
:param input_tokens: Number of tokens in the input
:param output_tokens: Number of tokens in the output
:param total_tokens: Total tokens used (input + output)
:param input_tokens_details: Detailed breakdown of input token usage
:param output_tokens_details: Detailed breakdown of output token usage
"""
input_tokens: int
output_tokens: int
total_tokens: int
input_tokens_details: OpenAIResponseUsageInputTokensDetails | None = None
output_tokens_details: OpenAIResponseUsageOutputTokensDetails | None = None
@json_schema_type
class OpenAIResponseObject(BaseModel):
"""Complete OpenAI response object containing generation results and metadata.
@ -362,7 +530,9 @@ class OpenAIResponseObject(BaseModel):
:param temperature: (Optional) Sampling temperature used for generation
:param text: Text formatting configuration for the response
:param top_p: (Optional) Nucleus sampling parameter used for generation
:param tools: (Optional) An array of tools the model may call while generating a response.
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
"""
created_at: int
@ -379,7 +549,9 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
@json_schema_type
@ -400,7 +572,7 @@ class OpenAIDeleteResponseObject(BaseModel):
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
"""Streaming event indicating a new response has been created.
:param response: The newly created response object
:param response: The response object that was created
:param type: Event type identifier, always "response.created"
"""
@ -408,11 +580,25 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseInProgress(BaseModel):
"""Streaming event indicating the response remains in progress.
:param response: Current response state while in progress
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.in_progress"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.in_progress"] = "response.in_progress"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
"""Streaming event indicating a response has been completed.
:param response: The completed response object
:param response: Completed response object
:param type: Event type identifier, always "response.completed"
"""
@ -420,6 +606,34 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
type: Literal["response.completed"] = "response.completed"
@json_schema_type
class OpenAIResponseObjectStreamResponseIncomplete(BaseModel):
"""Streaming event emitted when a response ends in an incomplete state.
:param response: Response object describing the incomplete state
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.incomplete"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.incomplete"] = "response.incomplete"
@json_schema_type
class OpenAIResponseObjectStreamResponseFailed(BaseModel):
"""Streaming event emitted when a response fails.
:param response: Response object describing the failure
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.failed"
"""
response: OpenAIResponseObject
sequence_number: int
type: Literal["response.failed"] = "response.failed"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
"""Streaming event for when a new output item is added to the response.
@ -650,19 +864,46 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
@json_schema_type
class OpenAIResponseContentPartOutputText(BaseModel):
"""Text content within a streamed response part.
:param type: Content part type identifier, always "output_text"
:param text: Text emitted for this content part
:param annotations: Structured annotations associated with the text
:param logprobs: (Optional) Token log probability details
"""
type: Literal["output_text"] = "output_text"
text: str
# TODO: add annotations, logprobs, etc.
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
logprobs: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
@json_schema_type
class OpenAIResponseContentPartReasoningText(BaseModel):
"""Reasoning text emitted as part of a streamed response.
:param type: Content part type identifier, always "reasoning_text"
:param text: Reasoning text supplied by the model
"""
type: Literal["reasoning_text"] = "reasoning_text"
text: str
OpenAIResponseContentPart = Annotated[
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
Field(discriminator="type"),
]
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
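A rough sketch of dispatching on the content-part discriminator when rendering streamed output; the dict payloads are hand-written stand-ins rather than real API output.

def render_part(part: dict) -> str:
    # Keys mirror the content-part models above.
    match part["type"]:
        case "output_text":
            return part["text"]
        case "reasoning_text":
            return f"[reasoning] {part['text']}"
        case "refusal":
            return f"[refused] {part['refusal']}"
        case _:
            return ""


parts = [
    {"type": "reasoning_text", "text": "Checking the retrieved chunks..."},
    {"type": "output_text", "text": "Here is the answer.", "annotations": []},
]
print("\n".join(render_part(p) for p in parts))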
@ -672,15 +913,19 @@ register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
"""Streaming event for when a new content part is added to a response item.
:param content_index: Index position of the part within the content array
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param output_index: Index position of the output item in the response
:param part: The content part that was added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.added"
"""
content_index: int
response_id: str
item_id: str
output_index: int
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.added"] = "response.content_part.added"
@ -690,22 +935,269 @@ class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
"""Streaming event for when a content part is completed.
:param content_index: Index position of the part within the content array
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param output_index: Index position of the output item in the response
:param part: The completed content part
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.done"
"""
content_index: int
response_id: str
item_id: str
output_index: int
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.done"] = "response.content_part.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningTextDelta(BaseModel):
"""Streaming event for incremental reasoning text updates.
:param content_index: Index position of the reasoning content part
:param delta: Incremental reasoning text being added
:param item_id: Unique identifier of the output item being updated
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.reasoning_text.delta"
"""
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningTextDone(BaseModel):
"""Streaming event for when reasoning text is completed.
:param content_index: Index position of the reasoning content part
:param text: Final complete reasoning text
:param item_id: Unique identifier of the completed output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.reasoning_text.done"
"""
content_index: int
text: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
@json_schema_type
class OpenAIResponseContentPartReasoningSummary(BaseModel):
"""Reasoning summary part in a streamed response.
:param type: Content part type identifier, always "summary_text"
:param text: Summary text
"""
type: Literal["summary_text"] = "summary_text"
text: str
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
"""Streaming event for when a new reasoning summary part is added.
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param part: The summary part that was added
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_part.added"
"""
item_id: str
output_index: int
part: OpenAIResponseContentPartReasoningSummary
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
"""Streaming event for when a reasoning summary part is completed.
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param part: The completed summary part
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_part.done"
"""
item_id: str
output_index: int
part: OpenAIResponseContentPartReasoningSummary
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
"""Streaming event for incremental reasoning summary text updates.
:param delta: Incremental summary text being added
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_text.delta"
"""
delta: str
item_id: str
output_index: int
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
"""Streaming event for when reasoning summary text is completed.
:param text: Final complete summary text
:param item_id: Unique identifier of the output item
:param output_index: Index position of the output item
:param sequence_number: Sequential number for ordering streaming events
:param summary_index: Index of the summary part within the reasoning summary
:param type: Event type identifier, always "response.reasoning_summary_text.done"
"""
text: str
item_id: str
output_index: int
sequence_number: int
summary_index: int
type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseRefusalDelta(BaseModel):
"""Streaming event for incremental refusal text updates.
:param content_index: Index position of the content part
:param delta: Incremental refusal text being added
:param item_id: Unique identifier of the output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.refusal.delta"
"""
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.refusal.delta"] = "response.refusal.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseRefusalDone(BaseModel):
"""Streaming event for when refusal text is completed.
:param content_index: Index position of the content part
:param refusal: Final complete refusal text
:param item_id: Unique identifier of the output item
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.refusal.done"
"""
content_index: int
refusal: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.refusal.done"] = "response.refusal.done"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
"""Streaming event for when an annotation is added to output text.
:param item_id: Unique identifier of the item to which the annotation is being added
:param output_index: Index position of the output item in the response's output array
:param content_index: Index position of the content part within the output item
:param annotation_index: Index of the annotation within the content part
:param annotation: The annotation object being added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.output_text.annotation.added"
"""
item_id: str
output_index: int
content_index: int
annotation_index: int
annotation: OpenAIResponseAnnotations
sequence_number: int
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
"""Streaming event for file search calls in progress.
:param item_id: Unique identifier of the file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.in_progress"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
"""Streaming event for file search currently searching.
:param item_id: Unique identifier of the file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.searching"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
@json_schema_type
class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
"""Streaming event for completed file search calls.
:param item_id: Unique identifier of the completed file search call
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.file_search_call.completed"
"""
item_id: str
output_index: int
sequence_number: int
type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseInProgress
| OpenAIResponseObjectStreamResponseOutputItemAdded
| OpenAIResponseObjectStreamResponseOutputItemDone
| OpenAIResponseObjectStreamResponseOutputTextDelta
@ -725,6 +1217,20 @@ OpenAIResponseObjectStream = Annotated[
| OpenAIResponseObjectStreamResponseMcpCallCompleted
| OpenAIResponseObjectStreamResponseContentPartAdded
| OpenAIResponseObjectStreamResponseContentPartDone
| OpenAIResponseObjectStreamResponseReasoningTextDelta
| OpenAIResponseObjectStreamResponseReasoningTextDone
| OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded
| OpenAIResponseObjectStreamResponseReasoningSummaryPartDone
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDone
| OpenAIResponseObjectStreamResponseRefusalDelta
| OpenAIResponseObjectStreamResponseRefusalDone
| OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded
| OpenAIResponseObjectStreamResponseFileSearchCallInProgress
| OpenAIResponseObjectStreamResponseFileSearchCallSearching
| OpenAIResponseObjectStreamResponseFileSearchCallCompleted
| OpenAIResponseObjectStreamResponseIncomplete
| OpenAIResponseObjectStreamResponseFailed
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
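Because every member of this union carries a `type` discriminator, a consumer can drive the whole response lifecycle with one dispatch. The loop below is a sketch over hand-written events; the `response.output_text.delta` type string is assumed to follow the OpenAI naming convention rather than taken from this file.

TERMINAL = {"response.completed", "response.incomplete", "response.failed"}


def handle(event: dict) -> bool:
    """Return True once the stream has reached a terminal event."""
    etype = event["type"]
    if etype.endswith(".delta") and "delta" in event:
        print(event["delta"], end="", flush=True)
    elif etype in TERMINAL:
        print(f"\n[{etype}]")
        return True
    return False


for ev in [
    {"type": "response.created"},
    {"type": "response.output_text.delta", "delta": "Hello"},  # assumed type string
    {"type": "response.completed"},
]:
    if handle(ev):
        break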
@ -760,114 +1266,6 @@ OpenAIResponseInput = Annotated[
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
"""Web search tool configuration for OpenAI response inputs.
:param type: Web search tool type variant to use
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
"""
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
"""Function tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "function"
:param name: Name of the function that can be called
:param description: (Optional) Description of what the function does
:param parameters: (Optional) JSON schema defining the function's parameters
:param strict: (Optional) Whether to enforce strict parameter validation
"""
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
"""File search tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "file_search"
:param vector_store_ids: List of vector store identifiers to search within
:param filters: (Optional) Additional filters to apply to the search
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
:param ranking_options: (Optional) Options for ranking and scoring search results
"""
type: Literal["file_search"] = "file_search"
vector_store_ids: list[str]
filters: dict[str, Any] | None = None
max_num_results: int | None = Field(default=10, ge=1, le=50)
ranking_options: FileSearchRankingOptions | None = None
class ApprovalFilter(BaseModel):
"""Filter configuration for MCP tool approval requirements.
:param always: (Optional) List of tool names that always require approval
:param never: (Optional) List of tool names that never require approval
"""
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
"""Filter configuration for restricting which MCP tools can be used.
:param tool_names: (Optional) List of specific tool names that are allowed
"""
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
:param type: Tool type identifier, always "mcp"
:param server_label: Label to identify this MCP server
:param server_url: URL endpoint of the MCP server
:param headers: (Optional) HTTP headers to include when connecting to the server
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
"""
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
class ListOpenAIResponseInputItem(BaseModel):
"""List container for OpenAI response input items.

View file

@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
def __init__(self, message: str) -> None:
super().__init__(message)
class ConversationNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced conversation"""
def __init__(self, conversation_id: str) -> None:
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
class InvalidConversationIdError(ValueError):
"""raised when a conversation ID has an invalid format"""
def __init__(self, conversation_id: str) -> None:
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
super().__init__(message)
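A minimal sketch of guarding the new `conversation` parameter with these exceptions; the class is re-declared locally (mirroring the definition above) so the snippet is runnable on its own.

class InvalidConversationIdError(ValueError):
    """Raised when a conversation ID has an invalid format (local stand-in)."""

    def __init__(self, conversation_id: str) -> None:
        super().__init__(
            f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
        )


def validate_conversation_id(conversation_id: str) -> str:
    if not conversation_id.startswith("conv_"):
        raise InvalidConversationIdError(conversation_id)
    return conversation_id


print(validate_conversation_id("conv_abc123"))  # passes through unchanged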

View file

@ -96,7 +96,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar vector_dbs: Vector database management
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
@ -122,7 +121,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
models = "models"
shields = "shields"
vector_dbs = "vector_dbs"
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"

View file

@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Files(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
@ -113,7 +118,8 @@ class Files(Protocol):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""
"""Upload file.
Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with:
@ -137,7 +143,8 @@ class Files(Protocol):
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""
"""List files.
Returns a list of files that belong to the user's organization.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
@ -154,7 +161,8 @@ class Files(Protocol):
self,
file_id: str,
) -> OpenAIFileObject:
"""
"""Retrieve file.
Returns information about a specific file.
:param file_id: The ID of the file to use for this request.
@ -168,8 +176,7 @@ class Files(Protocol):
self,
file_id: str,
) -> OpenAIFileDeleteResponse:
"""
Delete a file.
"""Delete file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
@ -182,7 +189,8 @@ class Files(Protocol):
self,
file_id: str,
) -> Response:
"""
"""Retrieve file content.
Returns the contents of the specified file.
:param file_id: The ID of the file to use for this request.
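As a hedged sketch of the multipart upload described above: the server address, the `/v1` prefix, and the `assistants` purpose value are assumptions, not taken from this diff.

import httpx

with open("notes.txt", "rb") as f:
    r = httpx.post(
        "http://localhost:8321/v1/files",  # assumed server address
        files={"file": ("notes.txt", f)},  # multipart file part
        data={"purpose": "assistants"},    # assumed purpose value
        timeout=60,
    )
r.raise_for_status()
print(r.json())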

View file

@ -14,6 +14,7 @@ from typing import (
runtime_checkable,
)
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
@ -776,12 +777,14 @@ class OpenAIChoiceDelta(BaseModel):
:param refusal: (Optional) The refusal of the delta
:param role: (Optional) The role of the delta
:param tool_calls: (Optional) The tool calls of the delta
:param reasoning_content: (Optional) The reasoning content from the model (non-standard, for o1/o3 models)
"""
content: str | None = None
refusal: str | None = None
role: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
reasoning_content: str | None = None
@json_schema_type
@ -816,6 +819,42 @@ class OpenAIChoice(BaseModel):
logprobs: OpenAIChoiceLogprobs | None = None
class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI chat completion usage.
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
"""
reasoning_tokens: int | None = None
class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
"""Token details for prompt tokens in OpenAI chat completion usage.
:param cached_tokens: Number of tokens retrieved from cache
"""
cached_tokens: int | None = None
@json_schema_type
class OpenAIChatCompletionUsage(BaseModel):
"""Usage information for OpenAI chat completion.
:param prompt_tokens: Number of tokens in the prompt
:param completion_tokens: Number of tokens in the completion
:param total_tokens: Total tokens used (prompt + completion)
:param input_tokens_details: Detailed breakdown of input token usage
:param output_tokens_details: Detailed breakdown of output token usage
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None
@json_schema_type
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request.
@ -825,6 +864,7 @@ class OpenAIChatCompletion(BaseModel):
:param object: The object type, which will be "chat.completion"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
:param usage: Token usage information for the completion
"""
id: str
@ -832,6 +872,7 @@ class OpenAIChatCompletion(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
@ -843,6 +884,7 @@ class OpenAIChatCompletionChunk(BaseModel):
:param object: The object type, which will be "chat.completion.chunk"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
:param usage: Token usage information (typically included in final chunk with stream_options)
"""
id: str
@ -850,6 +892,7 @@ class OpenAIChatCompletionChunk(BaseModel):
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
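A hedged sketch of reading the new `usage` field from a streamed chat completion via an OpenAI-compatible client; the base URL, API key, and model identifier are assumptions, and the final-chunk behaviour relies on the standard `stream_options={"include_usage": True}` convention.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")  # assumed endpoint

stream = client.chat.completions.create(
    model="llama3.2-3b-instruct",  # hypothetical model id
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},  # ask for usage on the final chunk
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if chunk.usage:  # populated only on the last chunk
        print(f"\n{chunk.usage.prompt_tokens} prompt / {chunk.usage.completion_tokens} completion tokens")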
@json_schema_type
@ -995,6 +1038,127 @@ class ListOpenAIChatCompletionResponse(BaseModel):
object: Literal["list"] = "list"
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
"""
# Standard OpenAI completion parameters
model: str
prompt: str | list[str] | list[int] | list[list[int]]
best_of: int | None = None
echo: bool | None = None
frequency_penalty: float | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
suffix: str | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible chat completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
"""
# Standard OpenAI chat completion parameters
model: str
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
frequency_penalty: float | None = None
function_call: str | dict[str, Any] | None = None
functions: list[dict[str, Any]] | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_completion_tokens: int | None = None
max_tokens: int | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
presence_penalty: float | None = None
response_format: OpenAIResponseFormatParam | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
top_p: float | None = None
user: str | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
"""Request parameters for OpenAI-compatible embeddings endpoint.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
"""
model: str
input: str | list[str]
encoding_format: str | None = "float"
dimensions: int | None = None
user: str | None = None
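The `extra="allow"` pattern shared by these request models can be demonstrated in isolation: unknown, provider-specific keys are retained and surface through `.model_extra` instead of being rejected. The `guided_choice` key below is purely illustrative.

from pydantic import BaseModel


class CompletionRequest(BaseModel, extra="allow"):
    model: str
    prompt: str


req = CompletionRequest(
    model="llama3.2-3b-instruct",  # hypothetical model id
    prompt="def fib(n):",
    guided_choice=["yes", "no"],   # not a declared field; kept as extra
)
print(req.model_extra)  # {'guided_choice': ['yes', 'no']}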
@runtime_checkable
@trace_protocol
class InferenceProvider(Protocol):
@ -1029,50 +1193,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
"""Create completion.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:returns: An OpenAICompletion.
"""
...
@ -1081,55 +1206,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
"""Create chat completions.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:returns: An OpenAIChatCompletion.
"""
...
@ -1138,26 +1219,20 @@ class InferenceProvider(Protocol):
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
"""Create embeddings.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
Generate OpenAI-compatible embeddings for the given input using the specified model.
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
"""
...
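A standalone FastAPI sketch (hypothetical route and model names) of the binding pattern these signatures move to: the whole JSON body is parsed into one request model, so new provider-specific parameters ride along in `model_extra` without further signature changes.

from typing import Annotated

from fastapi import Body, FastAPI
from pydantic import BaseModel


class EmbeddingsRequest(BaseModel, extra="allow"):
    model: str
    input: str | list[str]
    encoding_format: str | None = "float"


app = FastAPI()


@app.post("/v1/embeddings")
async def embeddings(params: Annotated[EmbeddingsRequest, Body(...)]) -> dict:
    # Unknown keys sent by the client end up in params.model_extra.
    return {"model": params.model, "extra": params.model_extra}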
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
"""Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
@ -1173,7 +1248,7 @@ class Inference(InferenceProvider):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
"""List chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
@ -1188,7 +1263,9 @@ class Inference(InferenceProvider):
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
"""Get chat completion.
Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.

View file

@ -58,25 +58,36 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
"""List all available API routes with their methods and implementing providers.
"""List routes.
List all available API routes with their methods and implementing providers.
:returns: Response containing information about all available routes.
"""
...
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def health(self) -> HealthInfo:
"""Get the current health status of the service.
"""Get health status.
Get the current health status of the service.
:returns: Health information indicating if the service is operational.
"""
...
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def version(self) -> VersionInfo:
"""Get the version of the service.
"""Get version.
Get the version of the service.
:returns: Version information containing the service version number.
"""

View file

@ -124,7 +124,9 @@ class Models(Protocol):
self,
model_id: str,
) -> Model:
"""Get a model by its identifier.
"""Get model.
Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
@ -140,7 +142,9 @@ class Models(Protocol):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
"""Register a model.
"""Register model.
Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
@ -156,7 +160,9 @@ class Models(Protocol):
self,
model_id: str,
) -> None:
"""Unregister a model.
"""Unregister model.
Unregister a model.
:param model_id: The identifier of the model to unregister.
"""

View file

@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Prompts(Protocol):
"""Protocol for prompt management operations."""
"""Prompts
Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse:
@ -109,7 +111,9 @@ class Prompts(Protocol):
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List all versions of a specific prompt.
"""List prompt versions.
List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt.
@ -122,7 +126,9 @@ class Prompts(Protocol):
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get a prompt by its identifier and optional version.
"""Get prompt.
Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest).
@ -136,7 +142,9 @@ class Prompts(Protocol):
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt.
"""Create prompt.
Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template.
@ -153,7 +161,9 @@ class Prompts(Protocol):
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version).
"""Update prompt.
Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
@ -169,7 +179,9 @@ class Prompts(Protocol):
self,
prompt_id: str,
) -> None:
"""Delete a prompt.
"""Delete prompt.
Delete a prompt.
:param prompt_id: The identifier of the prompt to delete.
"""
@ -181,7 +193,9 @@ class Prompts(Protocol):
prompt_id: str,
version: int,
) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest).
"""Set prompt version.
Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt.
:param version: The version to set as default.

View file

@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):
@runtime_checkable
class Providers(Protocol):
"""
"""Providers
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
"""List providers.
List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
@ -56,7 +59,9 @@ class Providers(Protocol):
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
"""Get provider.
Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.

View file

@ -9,7 +9,7 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -96,16 +96,23 @@ class ShieldStore(Protocol):
@runtime_checkable
@trace_protocol
class Safety(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run a shield.
"""Run shield.
Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
@ -117,7 +124,9 @@ class Safety(Protocol):
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.
"""Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use.
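A hedged sketch of calling the moderations route declared above; the server address, `/v1` prefix, and moderation model identifier are assumptions.

import httpx

resp = httpx.post(
    "http://localhost:8321/v1/moderations",
    json={"input": "some user text to classify", "model": "llama-guard-3-1b"},  # hypothetical model id
    timeout=30,
)
resp.raise_for_status()
print(resp.json())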

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type
@json_schema_type
@@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
"""
data: list[VectorDB]
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@@ -11,6 +11,7 @@
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
@@ -466,6 +467,40 @@ class VectorStoreFilesListInBatchResponse(BaseModel):
has_more: bool = False
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store with extra_body support.
:param name: (Optional) A name for the vector store
:param file_ids: List of file IDs to include in the vector store
:param expires_after: (Optional) Expiration policy for the vector store
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
:param metadata: Set of key-value pairs that can be attached to the vector store
"""
name: str | None = None
file_ids: list[str] | None = None
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="allow"):
"""Request to create a vector store file batch with extra_body support.
:param file_ids: A list of File IDs that the vector store should use
:param attributes: (Optional) Key-value attributes to store with the files
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto
"""
file_ids: list[str]
attributes: dict[str, Any] | None = None
chunking_strategy: VectorStoreChunkingStrategy | None = None
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@@ -516,25 +551,11 @@ class VectorIO(Protocol):
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
"""Creates a vector store.
:param name: A name for the vector store.
:param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
:param expires_after: The expiration policy for a vector store.
:param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
Generate an OpenAI-compatible vector store with the given parameters.
:returns: A VectorStoreObject representing the created vector store.
"""
...
@@ -827,16 +848,12 @@ class VectorIO(Protocol):
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.
Generate an OpenAI-compatible vector store file batch for the given vector store.
:param vector_store_id: The ID of the vector store to create the file batch for.
:param file_ids: A list of File IDs that the vector store should use.
:param attributes: (Optional) Key-value attributes to store with the files.
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
:returns: A VectorStoreFileBatchObject representing the created file batch.
"""
...
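
Because both request models are declared with extra="allow", provider-specific knobs can ride along in the request body and surface via .model_extra. A sketch, assuming the class is exported from llama_stack.apis.vector_io and that embedding_model/embedding_dimension are the extras a provider expects:

from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody  # assumed export path

request = OpenAICreateVectorStoreRequestWithExtraBody(
    name="my-notes",
    file_ids=["file_123"],                # hypothetical file ID
    metadata={"team": "docs"},
    embedding_model="all-MiniLM-L6-v2",   # extra field, kept because extra="allow"
    embedding_dimension=384,              # extra field, kept because extra="allow"
)

extras = request.model_extra or {}
print(extras.get("embedding_model"), extras.get("embedding_dimension"))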

View file

@@ -1,495 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
import shutil
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import partial
from pathlib import Path
import httpx
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama download",
description="Download a model from llama.meta.com or Hugging Face Hub",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_download_parser(self.parser)
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--source",
choices=["meta", "huggingface"],
default="meta",
)
parser.add_argument(
"--model-id",
required=False,
help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B",
)
parser.add_argument(
"--hf-token",
type=str,
required=False,
default=None,
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
)
parser.add_argument(
"--meta-url",
type=str,
required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
)
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument(
"--ignore-patterns",
type=str,
required=False,
default="*.safetensors",
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
)
parser.add_argument(
"--manifest-file",
type=str,
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
required=False,
)
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
@dataclass
class DownloadTask:
url: str
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: int | None = None
retries: int = 0
max_retries: int = 3
class DownloadError(Exception):
pass
class CustomTransferSpeedColumn(TransferSpeedColumn):
def render(self, task):
if task.finished:
return "-"
return super().render(task)
class ParallelDownloader:
def __init__(
self,
max_concurrent_downloads: int = 3,
buffer_size: int = 1024 * 1024,
timeout: int = 30,
):
self.max_concurrent_downloads = max_concurrent_downloads
self.buffer_size = buffer_size
self.timeout = timeout
self.console = Console()
self.progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=40),
"[progress.percentage]{task.percentage:>3.1f}%",
DownloadColumn(),
CustomTransferSpeedColumn(),
TimeRemainingColumn(),
console=self.console,
expand=True,
)
self.client_options = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
}
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
last_exception = None
for attempt in range(task.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < task.max_retries - 1:
wait_time = min(30, 2**attempt) # Cap at 30 seconds
self.console.print(
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
)
await asyncio.sleep(wait_time)
continue
raise last_exception
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
if task.total_size > 0:
self.progress.update(task.task_id, total=task.total_size)
return
async def _get_info():
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response.raise_for_status()
return response
try:
response = await self.retry_with_exponential_backoff(task, _get_info)
task.url = str(response.url)
task.total_size = int(response.headers.get("Content-Length", 0))
if task.total_size == 0:
raise DownloadError(
f"Unable to determine file size for {task.output_file}. "
"The server might not support range requests."
)
# Update the progress bar's total size once we know it
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
except httpx.HTTPError as e:
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
raise
def verify_file_integrity(self, task: DownloadTask) -> bool:
if not os.path.exists(task.output_file):
return False
return os.path.getsize(task.output_file) == task.total_size
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
file.seek(start)
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
try:
await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e:
raise DownloadError(
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
) from e
async def prepare_download(self, task: DownloadTask) -> None:
output_dir = os.path.dirname(task.output_file)
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(task.output_file):
task.downloaded_size = os.path.getsize(task.output_file)
async def download_file(self, task: DownloadTask) -> None:
try:
async with httpx.AsyncClient(**self.client_options) as client:
await self.get_file_info(client, task)
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
try:
# Split the remaining download into chunks
chunk_size = 27_000_000_000 # Cloudfront max chunk size
chunks = []
current_pos = task.downloaded_size
while current_pos < task.total_size:
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1
# Download chunks in sequence
for chunk_start, chunk_end in chunks:
await self.download_chunk(client, task, chunk_start, chunk_end)
except Exception as e:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
try:
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free
# Add 10% buffer for safety
required_space = int(total_remaining_size * 1.1)
if free_space < required_space:
self.console.print(
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
f"Available: {free_space // (1024 * 1024)} MB[/red]"
)
return False
return True
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: list[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
async def download_with_semaphore(task: DownloadTask):
async with semaphore:
try:
await self.download_file(task)
except Exception as e:
failed_tasks.append((task, str(e)))
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
raise DownloadError(f"{len(failed_tasks)} downloads failed")
def _hf_download(
model: "Model",
hf_token: str,
ignore_patterns: str,
parser: argparse.ArgumentParser,
):
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_stack.core.utils.model_utils import model_local_dir
repo_id = model.huggingface_repo
if repo_id is None:
raise ValueError(f"No repo id found for model {model.descriptor()}")
output_dir = model_local_dir(model.descriptor())
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
repo_id,
local_dir=output_dir,
ignore_patterns=ignore_patterns,
token=hf_token,
library_name="llama-stack",
)
except GatedRepoError:
parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
except Exception as e:
parser.error(e)
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
from llama_stack.core.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
# Create download tasks for each file
tasks = []
for f in info.files:
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
file=sys.stderr,
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
color="yellow",
file=sys.stderr,
)
class ModelEntry(BaseModel):
model_id: str
files: dict[str, str]
model_config = ConfigDict(protected_namespaces=())
class Manifest(BaseModel):
models: list[ModelEntry]
expires_on: datetime
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.core.utils.model_utils import model_local_dir
with open(manifest_file) as f:
d = json.load(f)
manifest = Manifest(**d)
if datetime.now(UTC) > manifest.expires_on.astimezone(UTC):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
for entry in manifest.models:
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
while True:
resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
break
elif resp.lower() in ["continue", "c"]:
console.print("[blue]Continuing download...[/blue]")
break
else:
console.print("[red]Invalid response. Please try again.[/red]")
# Create download tasks for all files in the manifest
tasks = [
DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
for fname, url in entry.files.items()
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Main download command handler"""
try:
if args.manifest_file:
_download_from_manifest(args.manifest_file, args.max_parallel)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
# Handle comma-separated model IDs
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import (
prompt_guard_download_info_map,
prompt_guard_model_sku_map,
)
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
prompt_guard_download_info_map = prompt_guard_download_info_map()
for model_id in model_ids:
if model_id in prompt_guard_model_sku_map.keys():
model = prompt_guard_model_sku_map[model_id]
info = prompt_guard_download_info_map[model_id]
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email "
f"after visiting https://www.llama.com/llama-downloads/ "
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")

View file

@@ -6,11 +6,8 @@
import argparse
from .download import Download
from .model import ModelParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload
class LlamaCLIParser:
@@ -30,10 +27,7 @@ class LlamaCLIParser:
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
ModelParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .model import ModelParser # noqa

View file

@@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
class ModelDescribe(Subcommand):
"""Show details about a model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"describe",
prog="llama model describe",
description="Show details about a llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_describe_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-id",
type=str,
required=True,
help="See `llama model list` or `llama model list --show-all` for the list of available models",
)
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku_map
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model_id in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model_id]
else:
model = resolve_model(args.model_id)
if model is None:
self.parser.error(
f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
)
return
headers = [
"Model",
model.descriptor(),
]
rows = [
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
("Description", model.description),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
("Weights format", model.quantization_format.value),
("Model params.json", json.dumps(model.arch_args, indent=4)),
]
print_table(
rows,
headers,
separate_rows=True,
)

View file

@@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama model download",
description="Download a model from llama.meta.com or Hugging Face Hub",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.download import setup_download_parser
setup_download_parser(self.parser)

View file

@@ -1,119 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import time
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import all_registered_models
def _get_model_size(model_dir):
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
def _convert_to_model_descriptor(model):
for m in all_registered_models():
if model == m.descriptor().replace(":", "-"):
return str(m.descriptor())
return str(model)
def _run_model_list_downloaded_cmd() -> None:
headers = ["Model", "Size", "Modified Time"]
rows = []
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
space_usage = _get_model_size(abs_path)
model_size = f"{space_usage / (1024**3):.2f} GB"
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
rows.append(
[
_convert_to_model_descriptor(model),
model_size,
modified_time,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
class ModelList(Subcommand):
"""List available llama models"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama model list",
description="Show available llama models",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_list_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--show-all",
action="store_true",
help="Show all models (not just defaults)",
)
self.parser.add_argument(
"--downloaded",
action="store_true",
help="List the downloaded models",
)
self.parser.add_argument(
"-s",
"--search",
type=str,
required=False,
help="Search for the input string as a substring in the model descriptor(ID)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_skus
if args.downloaded:
return _run_model_list_downloaded_cmd()
headers = [
"Model Descriptor(ID)",
"Hugging Face Repo",
"Context Length",
]
rows = []
for model in all_registered_models() + prompt_guard_model_skus():
if not args.show_all and not model.is_featured:
continue
descriptor = model.descriptor()
if not args.search or args.search.lower() in descriptor.lower():
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
)
if len(rows) == 0:
print(f"Did not find any model matching `{args.search}`.")
else:
print_table(
rows,
headers,
separate_rows=True,
)

View file

@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
class ModelParser(Subcommand):
"""Llama cli for model interface apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"model",
prog="llama model",
description="Work with llama models",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commands
ModelDownload.create(subparsers)
ModelList.create(subparsers)
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
ModelRemove.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@@ -1,133 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from io import StringIO
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent
class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"prompt-format",
prog="llama model prompt-format",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model prompt-format <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-name",
type=str,
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
"(Run `llama model list` to see a list of valid model names)",
)
self.parser.add_argument(
"-l",
"--list",
action="store_true",
help="List all available models",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_list = [m.value for m in supported_model_ids]
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
if model_id not in supported_model_ids:
self.parser.error(
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()
elif model_family(model_id) == ModelFamily.llama3_2:
if is_multimodal(model_id):
with importlib.resources.as_file(llama_3_2_vision_file) as f:
content = f.open("r").read()
else:
with importlib.resources.as_file(llama_3_2_text_file) as f:
content = f.open("r").read()
render_markdown_to_pager(content)
def render_markdown_to_pager(markdown_content: str):
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.text import Text
class LeftAlignedHeaderMarkdown(Markdown):
def parse_header(self, token):
level = token.type.count("h")
content = Text(token.content)
header_style = Style(color="bright_blue", bold=True)
header = Text(f"{'#' * level} ", style=header_style) + content
self.add_text(header)
# Render the Markdown
md = LeftAlignedHeaderMarkdown(markdown_content)
# Capture the rendered output
output = StringIO()
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
print(rendered_content)

View file

@@ -1,68 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import shutil
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import resolve_model
class ModelRemove(Subcommand):
"""Remove the downloaded llama model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"remove",
prog="llama model remove",
description="Remove the downloaded llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_remove_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model",
required=True,
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
)
self.parser.add_argument(
"-f",
"--force",
action="store_true",
help="Used to forcefully remove the llama model from the storage without further confirmation",
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku_map
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model]
else:
model = resolve_model(args.model)
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
if model is None or not os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return
if args.force:
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
print("Removal aborted.")

View file

@@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
model_id: str
huggingface_repo: str
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
is_featured: bool = False
max_seq_length: int = 512
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: dict[str, Any] = Field(default_factory=dict)
def descriptor(self) -> str:
return self.model_id
model_config = ConfigDict(protected_namespaces=())
def prompt_guard_model_skus():
return [
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-86M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-22M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
),
]
def prompt_guard_model_sku_map() -> dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()}
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
return {
model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
files=[
"model.safetensors",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer_config.json",
],
pth_size=1,
)
for model in prompt_guard_model_skus()
}

View file

@@ -1,24 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums for models downloaded from Meta",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.verify_download import setup_verify_download_parser
setup_verify_download_parser(self.parser)

View file

@@ -439,12 +439,24 @@ def _run_stack_build_command_from_build_config(
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
cprint(
"You can run the new Llama Stack distro via: "
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
color="green",
file=sys.stderr,
)
if build_config.image_type == LlamaStackImageType.VENV:
cprint(
"You can run the new Llama Stack distro (after activating "
+ colored(image_name, "cyan")
+ ") via: "
+ colored(f"llama stack run {run_config_file}", "blue"),
color="green",
file=sys.stderr,
)
elif build_config.image_type == LlamaStackImageType.CONTAINER:
cprint(
"You can run the container with: "
+ colored(
f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
),
color="green",
file=sys.stderr,
)
return distro_path
else:
return _generate_run_config(build_config, build_dir, image_name)

View file

@@ -6,11 +6,18 @@
import argparse
import os
import ssl
import subprocess
from pathlib import Path
import uvicorn
import yaml
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -48,18 +55,12 @@ class StackRun(Subcommand):
"--image-name",
type=str,
default=None,
help="Name of the image to run. Defaults to the current environment",
)
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
metavar="KEY=VALUE",
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
)
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type used during the build. This can be only venv.",
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
)
self.parser.add_argument(
@@ -68,48 +69,22 @@ class StackRun(Subcommand):
help="Start the UI server",
)
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and distribution name from args.config"""
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
distro_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a distribution
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
if config_file.exists():
distro_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, distro_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.utils.exec import formulate_run_args, run_command
if args.image_type or args.image_name:
self.parser.error(
"The --image-type and --image-name flags are no longer supported.\n\n"
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
"For example:\n"
" source /path/to/venv/bin/activate\n"
" llama stack run <config>\n"
)
if args.enable_ui:
self._start_ui_development_server(args.port)
image_type, image_name = args.image_type, args.image_name
if args.config:
try:
@@ -121,10 +96,6 @@ class StackRun(Subcommand):
else:
config_file = None
# Check if config is required based on image type
if image_type == ImageType.VENV.value and not config_file:
self.parser.error("Config file is required for venv environment")
if config_file:
logger.info(f"Using run configuration: {config_file}")
@@ -139,50 +110,67 @@ class StackRun(Subcommand):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
self._uvicorn_run(config_file, args)
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
if not config_file:
self.parser.error("Config file is required")
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
uvicorn_config = {
"factory": True,
"host": host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
if config.server.tls_cafile:
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
config = None
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:
logger.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.core.server.server import main as server_main
logger.info(f"Listening on {host}:{port}")
# Build the server args from the current args passed to the CLI
server_args = argparse.Namespace()
for arg in vars(args):
# If this is a function, avoid passing it
# "args" contains:
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
if arg == "config":
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
else:
run_args = formulate_run_args(image_type, image_name)
run_args.extend([str(args.port)])
if config_file:
run_args.extend(["--config", str(config_file)])
if args.env:
for env_var in args.env:
if "=" not in env_var:
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
run_command(run_args)
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")

View file

@@ -1,141 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.cli.subcommand import Subcommand
@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: str | None
exists: bool
matches: bool
class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify (only for models downloaded from Meta)",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
sha256_hash = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
sha256_hash.update(chunk)
return sha256_hash.hexdigest()
def load_checksums(checklist_path: Path) -> dict[str, str]:
checksums = {}
with open(checklist_path) as f:
for line in f:
if line.strip():
sha256sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = sha256sum
return checksums
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
results = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
matches = False
if exists:
actual_hash = calculate_sha256(full_path)
matches = actual_hash == expected_hash
results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)
progress.remove_task(task_id)
return results
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.core.utils.model_utils import model_local_dir
console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"
if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")
if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")
checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)
# Print results
console.print("\nVerification Results:")
all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")
if all_good:
console.print("\n[green]All files verified successfully![/green]")

View file

@@ -324,14 +324,14 @@ fi
RUN pip uninstall -y uv
EOF
# If a run config is provided, we use the --config flag
# If a run config is provided, we use the llama stack CLI
if [[ -n "$run_config" ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
EOF
elif [[ "$distro_or_config" != *.yaml ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
EOF
fi

View file

@@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import (
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations")
logger = get_logger(name=__name__, category="openai_conversations")
class ConversationServiceConfig(BaseModel):
@@ -196,12 +196,15 @@ class ConversationServiceImpl(Conversations):
await self._get_validated_conversation(conversation_id)
created_items = []
created_at = int(time.time())
base_time = int(time.time())
for item in items:
for i, item in enumerate(items):
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
# make each timestamp unique to maintain order
created_at = base_time + i
item_record = {
"id": item_id,
"conversation_id": conversation_id,

View file

@@ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
@@ -243,6 +239,7 @@ def get_external_providers_from_module(
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
@@ -251,9 +248,20 @@
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
if isinstance(spec, list):
# optionally allow people to pass inline and remote provider specs as a returned list.
# with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"

View file

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous
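
A sketch of how a test might pin IDs with the new hooks; the module path llama_stack.core.id_generation is a guess at where this file lives, and the kind strings are illustrative:

import uuid

# assumed module path for the helpers added above
from llama_stack.core.id_generation import generate_object_id, reset_id_override, set_id_override

def test_ids_are_deterministic():
    counter = 0

    def fake_id(kind: str, factory) -> str:
        nonlocal counter
        counter += 1
        return f"{kind}_{counter:04d}"

    previous = set_id_override(fake_id)
    try:
        assert generate_object_id("conversation", lambda: uuid.uuid4().hex) == "conversation_0001"
        assert generate_object_id("conversation", lambda: uuid.uuid4().hex) == "conversation_0002"
    finally:
        reset_id_override(previous)   # restore whatever override was installed before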

View file

@@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
setup_logger,
start_trace,
)
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")
@@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._convert_body(func, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@@ -493,21 +495,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
if not body:
return {}
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
exclude_params = exclude_params or set()
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if is_unwrapped_body_param(param_type):
base_type = get_args(param_type)[0]
return {param.name: base_type(**body)}
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Check if there's an unwrapped body parameter among multiple parameters
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
unwrapped_body_param = None
for param in params_list:
if is_unwrapped_body_param(param.annotation):
unwrapped_body_param = param
break
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
@@ -517,5 +530,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
elif unwrapped_body_param and param.name == unwrapped_body_param.name:
# This is the unwrapped body param - construct it from remaining body keys
base_type = get_args(param.annotation)[0]
# Extract only the keys that aren't already used by other params
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
converted_body[param.name] = base_type(**remaining_keys)
return converted_body
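
The net effect of the unwrapped-body handling is easiest to see on a signature that mixes a path parameter with an Annotated[..., Body(...)] model: keys already consumed by ordinary parameters pass through unchanged, and the leftovers are fed into the Pydantic model. A self-contained approximation (the helper below is illustrative and uses get_args() as a crude stand-in for is_unwrapped_body_param):

import inspect
from typing import Annotated, get_args

from fastapi import Body
from pydantic import BaseModel

class CreateFileBatchRequest(BaseModel):  # stand-in for the WithExtraBody request models
    file_ids: list[str]
    attributes: dict | None = None

async def create_batch(vector_store_id: str, params: Annotated[CreateFileBatchRequest, Body(...)]):
    ...

def split_flat_body(func, body: dict) -> dict:
    sig = inspect.signature(func)
    converted: dict = {}
    model_param = None
    for name, param in sig.parameters.items():
        if name in body:
            converted[name] = body[name]          # e.g. the path parameter
        elif get_args(param.annotation):          # crude stand-in for is_unwrapped_body_param
            model_param = param
    if model_param is not None:
        remaining = {k: v for k, v in body.items() if k not in converted}
        converted[model_param.name] = get_args(model_param.annotation)[0](**remaining)
    return converted

kwargs = split_flat_body(create_batch, {"vector_store_id": "vs_1", "file_ids": ["file_1"]})
# kwargs == {"vector_store_id": "vs_1", "params": CreateFileBatchRequest(file_ids=["file_1"])}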

View file

@@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
@@ -55,7 +54,6 @@ from llama_stack.providers.datatypes import (
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
@@ -81,7 +79,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
@ -125,7 +122,6 @@ def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
@ -150,6 +146,7 @@ async def resolve_impls(
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@ -172,7 +169,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -280,9 +277,10 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
# Skip providers that are not enabled


@ -31,10 +31,8 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,


@ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from fastapi import Body
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -31,15 +32,17 @@ from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
StopReason,
ToolPromptFormat,
@ -181,61 +184,23 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_completion(**params)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True because we do not return
# an AsyncIterator; our tests expect a stream of chunks that we cannot currently intercept.
# response_stream = await provider.openai_completion(**params)
response = await provider.openai_completion(**params)
response = await provider.openai_completion(params)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
@ -254,93 +219,49 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(model, ModelType.llm)
model_obj = await self._get_model(params.model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
if params.tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
if params.tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
if params.tools:
for tool in params.tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
if params.tool_choice == "none" and params.tools is not None:
params.tool_choice = None
params.tools = None
# Update params with the resolved model identifier
params.model = model_obj.identifier
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if params.stream:
response_stream = await provider.openai_chat_completion(params)
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=messages,
messages=params.messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, messages))
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry:
metrics = self._construct_metrics(
@ -359,26 +280,18 @@ class InferenceRouter(Inference):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self._get_model(model, ModelType.embedding)
params = dict(
model=model_obj.identifier,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
)
model_obj = await self._get_model(params.model, ModelType.embedding)
# Update model to use resolved identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
return await provider.openai_embeddings(params)
async def list_chat_completions(
self,
@ -396,8 +309,10 @@ class InferenceRouter(Inference):
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
@ -611,7 +526,7 @@ class InferenceRouter(Inference):
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry and chunk.usage:
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
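The practical upshot of this file's changes is that the OpenAI-style inference entry points now take a single request model instead of a long keyword list, and the router simply rewrites the model field before handing the same object to the provider. A rough sketch of the new calling convention; the field values are illustrative, and the exact accepted fields are whatever OpenAIChatCompletionRequestWithExtraBody defines in llama_stack.apis.inference.

    from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody

    # One object carries the whole request; the "...WithExtraBody" models also tolerate
    # provider-specific extra fields on top of the standard OpenAI ones.
    params = OpenAIChatCompletionRequestWithExtraBody(
        model="llama3.2:3b",  # the router rewrites this to model_obj.identifier
        messages=[{"role": "user", "content": "Hello!"}],
        stream=False,
    )

    # The router-to-provider hand-off is now simply:
    #   response = await provider.openai_chat_completion(params)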


@ -6,12 +6,16 @@
import asyncio
import uuid
from typing import Any
from typing import Annotated, Any
from fastapi import Body
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
@ -51,30 +55,18 @@ class VectorIORouter(VectorIO):
logger.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
"""Get the embedding dimension for a specific embedding model."""
all_models = await self.routing_table.get_all_with_type("model")
# Filter for embedding models
embedding_models = [
model
for model in all_models
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
for model in all_models:
if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
dimension = model.metadata.get("embedding_dimension")
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
return None
raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
return int(dimension)
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
async def register_vector_db(
self,
@ -120,24 +112,35 @@ class VectorIORouter(VectorIO):
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# Extract llama-stack-specific parameters from extra_body
extra = params.model_extra or {}
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
# If no embedding model is provided, use the first available one
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
# Require explicit embedding model specification
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
raise ValueError("embedding_model is required in extra_body when creating a vector store")
if embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
# Auto-select provider if not specified
if provider_id is None:
num_providers = len(self.routing_table.impls_by_provider_id)
if num_providers == 0:
raise ValueError("No vector_io providers available")
if num_providers > 1:
available_providers = list(self.routing_table.impls_by_provider_id.keys())
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
@ -146,20 +149,19 @@ class VectorIORouter(VectorIO):
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_db_id=vector_db_id,
vector_db_name=name,
vector_db_name=params.name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=registered_vector_db.provider_id,
provider_vector_db_id=registered_vector_db.provider_resource_id,
)
# Update model_extra with registered values so provider uses the already-registered vector_db
if params.model_extra is None:
params.model_extra = {}
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
params.model_extra["provider_id"] = registered_vector_db.provider_id
params.model_extra["embedding_model"] = embedding_model
params.model_extra["embedding_dimension"] = embedding_dimension
return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores(
self,
@ -219,7 +221,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
@ -229,7 +232,8 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -241,7 +245,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,
@ -254,7 +259,8 @@ class VectorIORouter(VectorIO):
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -272,7 +278,8 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -289,7 +296,8 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -304,7 +312,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -315,7 +324,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -327,7 +337,8 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -339,7 +350,8 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -370,17 +382,13 @@ class VectorIORouter(VectorIO):
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
return await self.routing_table.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
logger.debug(
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(vector_store_id, params)
async def openai_retrieve_vector_store_file_batch(
self,
@ -388,7 +396,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
@ -404,7 +413,8 @@ class VectorIORouter(VectorIO):
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
@ -420,7 +430,8 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_cancel_vector_store_file_batch(
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
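Concretely, callers must now name the embedding model themselves. With an OpenAI-compatible client pointed at the stack, the llama-stack-specific settings ride along in extra_body; the endpoint URL, model name and provider id below are illustrative, and on older openai-python releases vector stores live under client.beta instead.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed-locally")

    vector_store = client.vector_stores.create(
        name="my-docs",
        extra_body={
            "embedding_model": "all-MiniLM-L6-v2",  # required; no more silent "first available model" fallback
            "embedding_dimension": 384,             # optional; looked up from the model metadata when omitted
            "provider_id": "faiss",                 # only needed when several vector_io providers are configured
        },
    )
    print(vector_store.id)  # a router-allocated "vs_..." identifier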


@ -9,7 +9,6 @@ from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithOwner,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
@ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):


@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
async def has_model(self, model_id: str) -> bool:
"""
Check if a model exists in the routing table.
:param model_id: The model identifier to check
:return: True if the model exists, False otherwise
"""
try:
await lookup_model(self, model_id)
return True
except ModelNotFoundError:
return False
async def register_model(
self,
model_id: str,


@ -1,247 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)


@ -27,6 +27,11 @@ class AuthenticationMiddleware:
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Unauthenticated Access:
Endpoints can opt out of authentication by setting require_authentication=False
in their @webmethod decorator. This is typically used for operational endpoints
like /health and /version to support monitoring, load balancers, and observability tools.
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
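As a hedged illustration of the opt-out described above (the decorator's import location and the Operations protocol are assumptions for the sketch, not verified repository code):

    from typing import Protocol

    from llama_stack.schema_utils import webmethod  # assumed import path for the decorator


    class Operations(Protocol):  # illustrative protocol, not an actual llama-stack API
        @webmethod(route="/health", method="GET", require_authentication=False)
        async def health(self) -> dict[str, str]:
            """Liveness probe; the middleware lets this through without a bearer token."""
            ...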
@ -88,7 +93,26 @@ class AuthenticationMiddleware:
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
# Find the route and check if authentication is required
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
webmethod = None
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, fall through and run auth anyway
pass
# If webmethod explicitly sets require_authentication=False, allow without auth
if webmethod and webmethod.require_authentication is False:
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
return await self.app(scope, receive, send)
# Handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@ -127,19 +151,7 @@ class AuthenticationMiddleware:
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
if webmethod and webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(


@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import concurrent.futures
import functools
@ -12,7 +11,6 @@ import inspect
import json
import logging # allow-direct-logging
import os
import ssl
import sys
import traceback
import warnings
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
Stack,
cast_image_name_to_string,
replace_env_vars,
validate_env_pair,
)
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
@ -142,6 +138,13 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
# Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
# These include AuthenticationError (401), PermissionDeniedError (403), etc.
# This preserves the actual HTTP status code from the provider
status_code = exc.status_code
detail = str(exc)
return HTTPException(status_code=status_code, detail=detail)
else:
return HTTPException(
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
@ -181,7 +184,17 @@ async def lifespan(app: StackApp):
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
# If there's a stream parameter at top level, use it
if "stream" in kwargs:
return kwargs["stream"]
# If there's a stream parameter inside a "params" request object (e.g. openai_chat_completion()), use it
if "params" in kwargs:
params = kwargs["params"]
if hasattr(params, "stream"):
return params.stream
return False
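Because the interleaved diff makes the control flow easy to misread, here is a tiny runnable mirror of the check (stream_flag and SimpleNamespace are stand-ins, not the server code itself):

    from types import SimpleNamespace


    def stream_flag(**kwargs) -> bool:
        """Mirror of is_streaming_request: a top-level `stream` kwarg wins, otherwise
        look for `params.stream` on a request object, and default to False."""
        if "stream" in kwargs:
            return bool(kwargs["stream"])
        return bool(getattr(kwargs.get("params"), "stream", False))


    assert stream_flag(stream=True)                              # legacy flat kwargs
    assert stream_flag(params=SimpleNamespace(stream=True))      # request-object style (e.g. chat completions)
    assert not stream_flag(params=SimpleNamespace(stream=False))
    assert not stream_flag()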
async def maybe_await(value):
@ -236,15 +249,31 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
await log_request_pre_validation(request)
test_context_token = None
test_context_var = None
reset_test_context_fn = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
test_context_var = TEST_CONTEXT
reset_test_context_fn = reset_test_context
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
if test_context_var is not None:
context_vars.append(test_context_var)
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
@ -262,6 +291,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None and reset_test_context_fn is not None:
reset_test_context_fn(test_context_token)
sig = inspect.signature(func)
@ -333,23 +365,18 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def create_app(
config_file: str | None = None,
env_vars: list[str] | None = None,
) -> StackApp:
def create_app() -> StackApp:
"""Create and configure the FastAPI application.
Args:
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
env_vars: List of environment variables in KEY=value format.
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
This factory function reads configuration from environment variables:
- LLAMA_STACK_CONFIG: Path to config file (required)
Returns:
Configured StackApp instance.
"""
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
config_file = os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
config_file = resolve_config_or_distro(config_file, Mode.RUN)
@ -361,16 +388,6 @@ def create_app(
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if env_vars:
for env_pair in env_vars:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
@ -494,101 +511,6 @@ def create_app(
return app
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn
# Configure SSL if certificates are provided
port = args.port or config.server.port
ssl_config = None
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
ssl_config = {
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
@ -615,7 +537,3 @@ def remove_disabled_providers(obj):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
if __name__ == "__main__":
main()
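With main() and the --env plumbing removed, this module no longer doubles as a script: servers are started through the llama stack run CLI (see the updated shell script below) or programmatically via the environment-driven factory. A minimal sketch, assuming the historical default port of 8321 and an illustrative config path:

    import os

    import uvicorn

    from llama_stack.core.server.server import create_app

    # The factory takes no arguments; all configuration comes from the environment.
    os.environ["LLAMA_STACK_CONFIG"] = "run.yaml"  # path to a run config, or a distro name

    app = create_app()
    uvicorn.run(app, host="0.0.0.0", port=8321)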


@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
Agents,
Safety,
@ -83,7 +81,6 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
@ -274,22 +271,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
key, value = env_pair.split("=", 1)
key = key.strip()
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
return key, value
except ValueError as e:
raise ValueError(
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
) from e
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
@ -332,22 +313,27 @@ class Stack:
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.persistence, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, self.run_config)
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
dist_registry,
policy,
internal_impls,
)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
@ -397,7 +383,7 @@ class Stack:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:


@ -25,7 +25,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
exit 1
fi
@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
# Initialize variables
yaml_config=""
env_vars=""
other_args=""
# Process remaining arguments
@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
exit 1
fi
;;
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
@ -116,10 +106,9 @@ if [[ "$env_type" == "venv" ]]; then
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.core.server.server \
llama stack run \
$yaml_config_arg \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"


@ -95,9 +95,11 @@ class DiskDistributionRegistry(DistributionRegistry):
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
return False
if existing_obj and existing_obj != obj:
raise ValueError(
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
"Unregister it first if you want to replace it."
)
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),


@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
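A small usage sketch of the new module, for readers unfamiliar with ContextVar tokens (the test id string is illustrative):

    from llama_stack.core.testing_context import get_test_context, reset_test_context, set_test_context

    token = set_test_context("tests/integration/inference/test_chat.py::test_basic")
    try:
        assert get_test_context() == "tests/integration/inference/test_chat.py::test_basic"
    finally:
        reset_test_context(token)  # restore the previous value, as the server request handler does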


@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
"Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
@ -37,8 +35,6 @@ def resources_page():
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":


@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")


@ -1,301 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
vector_io_provider = None
for x in providers:
if x.api == "vector_io":
vector_io_provider = x.provider_id
llama_stack_api.client.vector_dbs.register(
vector_db_id=vector_db_name, # Use the user-provided name
embedding_dimension=384,
embedding_model="all-MiniLM-L6-v2",
provider_id=vector_io_provider,
)
# insert documents using the custom vector db name
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
],
)
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -117,11 +117,11 @@ docker run -it \
# NOTE: mount the llama-stack directory if testing local changes else not needed
-v $HOME/git/llama-stack:/app/llama-stack-source \
# localhost/distribution-dell:dev if building / testing locally
-e INFERENCE_MODEL=$INFERENCE_MODEL \
-e DEH_URL=$DEH_URL \
-e CHROMA_URL=$CHROMA_URL \
llamastack/distribution-{{ name }}\
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env DEH_URL=$DEH_URL \
--env CHROMA_URL=$CHROMA_URL
--port $LLAMA_STACK_PORT
```
@ -142,14 +142,14 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
-e INFERENCE_MODEL=$INFERENCE_MODEL \
-e DEH_URL=$DEH_URL \
-e SAFETY_MODEL=$SAFETY_MODEL \
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
-e CHROMA_URL=$CHROMA_URL \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env DEH_URL=$DEH_URL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
--env CHROMA_URL=$CHROMA_URL
--port $LLAMA_STACK_PORT
```
### Via Conda
@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
```bash
llama stack build --distro {{ name }} --image-type conda
llama stack run {{ name }}
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env DEH_URL=$DEH_URL \
--env CHROMA_URL=$CHROMA_URL
INFERENCE_MODEL=$INFERENCE_MODEL \
DEH_URL=$DEH_URL \
CHROMA_URL=$CHROMA_URL \
llama stack run {{ name }} \
--port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
INFERENCE_MODEL=$INFERENCE_MODEL \
DEH_URL=$DEH_URL \
SAFETY_MODEL=$SAFETY_MODEL \
DEH_SAFETY_URL=$DEH_SAFETY_URL \
CHROMA_URL=$CHROMA_URL \
llama stack run ./run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env DEH_URL=$DEH_URL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
--env CHROMA_URL=$CHROMA_URL
--port $LLAMA_STACK_PORT
```

View file

@ -101,6 +101,9 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
conversations_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/conversations.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -29,31 +29,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](../../references/llama_cli_reference/download_models.md) for instructions on downloading the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
Please check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](../../references/llama_cli_reference/download_models.md) for instructions on downloading the models using the Hugging Face CLI.
```
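As a quick reference, a minimal download flow matching the commands mentioned above (model ids and flags are illustrative; run `llama model download --help` or see the installation guide for the options supported by your version):
```bash
# Illustrative only: substitute the checkpoint you actually want.
llama model list                          # browse the available checkpoints
llama model download --source meta --model-id Llama3.2-3B
# or via the Hugging Face CLI (gated repos may require `huggingface-cli login`):
huggingface-cli download meta-llama/Llama-3.2-3B-Instruct
llama model list --downloaded             # confirm the checkpoint is in ~/.llama
```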
## Running the Distribution
@ -72,9 +48,9 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
--port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
@ -86,10 +62,10 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
--port $LLAMA_STACK_PORT
```
### Via venv
@ -98,16 +74,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
llama stack build --distro {{ name }} --image-type venv
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llama stack run distributions/{{ name }}/run.yaml \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
--port 8321
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llama stack run distributions/{{ name }}/run-with-safety.yaml \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
--port 8321
```

View file

@ -114,6 +114,9 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
conversations_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/conversations.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -118,10 +118,10 @@ docker run \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--port $LLAMA_STACK_PORT
```
### Via venv
@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --distro nvidia --image-type venv
NVIDIA_API_KEY=$NVIDIA_API_KEY \
INFERENCE_MODEL=$INFERENCE_MODEL \
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
--port 8321
```
## Example Notebooks

View file

@ -103,6 +103,9 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
conversations_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/conversations.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -181,6 +181,7 @@ class RunConfigSettings(BaseModel):
default_benchmarks: list[BenchmarkInput] | None = None
metadata_store: dict | None = None
inference_store: dict | None = None
conversations_store: dict | None = None
def run_config(
self,
@ -240,6 +241,11 @@ class RunConfigSettings(BaseModel):
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="inference_store.db",
),
"conversations_store": self.conversations_store
or SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="conversations.db",
),
"models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
"shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
"vector_dbs": [],

View file

@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .watsonx import get_distribution_template # noqa: F401

View file

@ -3,44 +3,33 @@ distribution_spec:
description: Use watsonx for running LLM inference
providers:
inference:
- provider_id: watsonx
provider_type: remote::watsonx
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
- provider_type: remote::watsonx
- provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_type: inline::faiss
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
- provider_type: inline::llama-guard
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]
- aiosqlite
- aiosqlite

View file

@ -4,17 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
config=WatsonXConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"inference": [inference_provider],
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_models=[],
default_tool_groups=default_tool_groups,
),
},

View file

@ -30,13 +30,19 @@ CATEGORIES = [
"tools",
"client",
"telemetry",
"openai",
"openai_responses",
"openai_conversations",
"testing",
"providers",
"models",
"files",
"vector_io",
"tool_runtime",
"cli",
"post_training",
"scoring",
"tests",
]
UNCATEGORIZED = "uncategorized"
@ -128,7 +134,10 @@ def strip_rich_markup(text):
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console()
# Set a reasonable default width for console output, especially when redirected to files
console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
# Don't force terminal codes to avoid ANSI escape codes in log files
kwargs["console"] = Console(width=console_width)
super().__init__(*args, **kwargs)
def emit(self, record):
@ -261,11 +270,12 @@ def get_logger(
if root_category in _category_levels:
log_level = _category_levels[root_category]
else:
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
if category != UNCATEGORIZED:
logging.warning(
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
raise ValueError(
f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
)
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
logger.setLevel(log_level)
return logging.LoggerAdapter(logger, {"category": category})
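For orientation, a small usage sketch of the stricter category handling introduced above (the call style mirrors the `get_logger(__name__, "models")` usage elsewhere in this diff; treat it as illustrative rather than documented API):
```python
# Illustrative sketch based on the get_logger change shown above.
from llama_stack.log import get_logger

# A category from the CATEGORIES list resolves normally.
logger = get_logger(__name__, "openai_responses")
logger.info("responses request received")

# An unknown category now raises instead of silently falling back to "root".
try:
    get_logger(__name__, "not-a-category")
except ValueError as err:
    print(err)
```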

View file

@ -11,19 +11,13 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap
from pathlib import Path
from pydantic import BaseModel, Field
from llama_stack.models.llama.datatypes import (
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
StopReason,
ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat
return messages
def llama3_1_builtin_tool_call_with_image_dialog(
tool_prompt_format=ToolPromptFormat.json,
):
this_dir = Path(__file__).parent
with open(this_dir / "llama3/dog.jpg", "rb") as f:
img = f.read()
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_builtin_tools_only())
messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
messages += interface.assistant_response_messages(
"Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
StopReason.end_of_turn,
)
messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
return messages
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)
@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
return messages
def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_custom_tools_only())
messages += interface.user_message(content="Use tools to get latest trending songs")
messages.append(
RawMessage(
role="assistant",
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
call_id="call_id",
tool_name="trending_songs",
arguments={"n": "10", "genre": "latest"},
)
],
),
)
messages.append(
RawMessage(
role="assistant",
content=tool_response,
)
)
return messages
def llama3_2_user_assistant_conversation():
return UseCase(
title="User and assistant conversation",

View file

@ -9,7 +9,7 @@ from pathlib import Path
from llama_stack.log import get_logger
logger = get_logger(__name__, "tokenizer_utils")
logger = get_logger(__name__, "models")
def load_bpe_file(model_path: Path) -> dict[bytes, int]:

View file

@ -21,7 +21,9 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
Api.telemetry in deps,
)
await impl.initialize()
return impl

View file

@ -7,8 +7,6 @@
import copy
import json
import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
@ -51,6 +49,7 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -84,11 +83,6 @@ from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
@ -110,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -120,6 +115,7 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
self.telemetry_enabled = telemetry_enabled
ShieldRunnerMixin.__init__(
self,
@ -188,28 +184,30 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
turn_id = str(uuid.uuid4())
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools()
async for chunk in self._run_turn(request):
@ -395,9 +393,12 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str,
) -> AsyncGenerator:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if self.telemetry_enabled and span is not None:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
@ -430,7 +431,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
@ -453,7 +455,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", "no violations")
if self.telemetry_enabled and span is not None:
span.set_attribute("output", "no violations")
async def _run(
self,
@ -518,8 +521,9 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled and span is not None:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
@ -579,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
@ -590,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens=max_tokens,
stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
@ -637,18 +642,19 @@ class ChatAgent(ShieldRunnerMixin):
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
if self.telemetry_enabled and span is not None:
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
@ -756,7 +762,9 @@ class ChatAgent(ShieldRunnerMixin):
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
}
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_result = await self.execute_tool_call_maybe(
@ -771,7 +779,8 @@ class ChatAgent(ShieldRunnerMixin):
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(

View file

@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -63,7 +64,9 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
@ -71,6 +74,8 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@ -86,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
)
async def create_agent(
@ -135,6 +141,7 @@ class MetaReferenceAgentsImpl(Agents):
),
created_at=agent_info.created_at,
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
@ -322,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -336,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,

View file

@ -24,6 +24,11 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
@ -39,7 +44,7 @@ from llama_stack.providers.utils.responses.responses_store import (
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
@ -61,12 +66,14 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@ -91,13 +98,15 @@ class OpenAIResponsesImpl:
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion)
tuple: (all_input for storage, messages for chat completion, tool context)
"""
tool_context = ToolContext(tools)
if previous_response_id:
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
@ -108,16 +117,18 @@ class OpenAIResponsesImpl:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
tool_context.recover_tools_from_previous_response(previous_response)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
return all_input, messages
return all_input, messages, tool_context
async def _prepend_instructions(self, messages, instructions):
if instructions:
@ -201,6 +212,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -217,11 +229,27 @@ class OpenAIResponsesImpl:
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
if conversation is not None and previous_response_id is not None:
raise ValueError(
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
)
original_input = input # needed for syncing to Conversations
if conversation is not None:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
# Check conversation exists (raises ConversationNotFoundError if not)
_ = await self.conversations_api.get_conversation(conversation)
input = await self._load_conversation_context(conversation, input)
stream_gen = self._create_streaming_response(
input=input,
original_input=original_input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
text=text,
@ -232,24 +260,42 @@ class OpenAIResponsesImpl:
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
final_response = None
final_event_type = None
failed_response = None
if response is None:
raise ValueError("The response stream never completed")
return response
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if failed_response is not None:
error_message = (
failed_response.error.message
if failed_response and failed_response.error
else "Response stream failed without error details"
)
raise RuntimeError(f"OpenAI response failed: {error_message}")
if final_response is None:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
original_input: str | list[OpenAIResponseInput] | None = None,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
@ -257,7 +303,9 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id
)
await self._prepend_instructions(messages, instructions)
# Structured outputs
@ -269,11 +317,12 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
tool_context=tool_context,
inputs=all_input,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
@ -288,18 +337,110 @@ class OpenAIResponsesImpl:
# Stream the response
final_response = None
failed_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
# Store and sync immediately after yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks early
if (
stream_chunk.type in {"response.completed", "response.incomplete"}
and store
and final_response
and failed_response is None
):
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
if stream_chunk.type in {"response.completed", "response.incomplete"} and conversation and final_response:
# for Conversations, we need to use the original_input if it's available, otherwise use input
sync_input = original_input if original_input is not None else input
await self._sync_response_to_conversation(conversation, sync_input, final_response)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _load_conversation_context(
self, conversation_id: str, content: str | list[OpenAIResponseInput]
) -> list[OpenAIResponseInput]:
"""Load conversation history and merge with provided content."""
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
context_messages = []
for item in conversation_items.data:
if isinstance(item, OpenAIResponseMessage):
if item.role == "user":
context_messages.append(
OpenAIResponseMessage(
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
elif item.role == "assistant":
context_messages.append(
OpenAIResponseMessage(
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
# add new content to context
if isinstance(content, str):
context_messages.append(OpenAIResponseMessage(role="user", content=content))
elif isinstance(content, list):
context_messages.extend(content)
return context_messages
async def _sync_response_to_conversation(
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# add user content message(s)
if isinstance(content, str):
conversation_items.append(
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
)
elif isinstance(content, list):
for item in content:
if not isinstance(item, OpenAIResponseMessage):
raise NotImplementedError(f"Unsupported input item type: {type(item)}")
if item.role == "user":
if isinstance(item.content, str):
conversation_items.append(
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": item.content}],
}
)
elif isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "user", "content": item.content})
else:
raise NotImplementedError(f"Unsupported user message content type: {type(item.content)}")
elif item.role == "assistant":
if isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": item.content})
else:
raise NotImplementedError(f"Unsupported assistant message content type: {type(item.content)}")
else:
raise NotImplementedError(f"Unsupported message role: {item.role}")
# add assistant response message
for output_item in response.output:
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
if hasattr(output_item, "content") and isinstance(output_item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": output_item.content})
if conversation_items:
adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)
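For orientation, a hedged end-to-end sketch of how a caller might exercise the new conversation wiring. The client construction, base URL, model id, and conversation id below are assumptions for illustration; the only constraint shown in the code above is the `conv_` prefix requirement.
```python
# Hedged sketch, not the project's documented client usage.
from openai import OpenAI

# Assumes a Llama Stack server exposing its OpenAI-compatible API at this base URL.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

resp = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Summarize our discussion so far.",
    # Conversation ids must carry the "conv_" prefix, per the validation above.
    # Passed via extra_body so this works even if the installed SDK does not
    # expose a dedicated `conversation` keyword.
    extra_body={"conversation": "conv_0123456789abcdef"},
)
print(resp.output_text)
```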

View file

@ -13,6 +13,9 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
@ -22,8 +25,11 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFailed,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseIncomplete,
OpenAIResponseObjectStreamResponseInProgress,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
@ -31,21 +37,31 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDone,
OpenAIResponseObjectStreamResponseRefusalDelta,
OpenAIResponseObjectStreamResponseRefusalDone,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
@ -94,113 +110,174 @@ class StreamingResponseOrchestrator:
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
for item in outputs:
if hasattr(item, "model_copy"):
cloned.append(item.model_copy(deep=True))
else:
cloned.append(item)
return cloned
def _snapshot_response(
self,
status: str,
outputs: list[OpenAIResponseOutput],
*,
error: OpenAIResponseError | None = None,
) -> OpenAIResponseObject:
return OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
status=status,
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
# Emit response.created followed by response.in_progress to align with OpenAI streaming
yield OpenAIResponseObjectStreamResponseCreated(
response=self._snapshot_response("in_progress", output_messages)
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseInProgress(
response=self._snapshot_response("in_progress", output_messages),
sequence_number=self.sequence_number,
)
async for stream_event in self._process_tools(output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
last_completion_result: ChatCompletionResult | None = None
while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
)
try:
while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
stream_options={
"include_usage": True,
},
)
completion_result = await self.inference_api.openai_chat_completion(params)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
(
function_tool_calls,
non_function_tool_calls,
approvals,
next_turn_messages,
) = self._separate_tool_calls(current_response, messages)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
self.citation_files,
message_id=completion_result_data.message_item_id,
)
)
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield evt
yield stream_event
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
messages = next_turn_messages
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
)
final_status = "incomplete"
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
if last_completion_result and last_completion_result.finish_reason == "length":
final_status = "incomplete"
messages = next_turn_messages
except Exception as exc: # noqa: BLE001
self.final_messages = messages.copy()
self.sequence_number += 1
error = OpenAIResponseError(code="internal_error", message=str(exc))
failure_response = self._snapshot_response("failed", output_messages, error=error)
yield OpenAIResponseObjectStreamResponseFailed(
response=failure_response,
sequence_number=self.sequence_number,
)
return
self.final_messages = messages.copy() + [current_response.choices[0].message]
self.final_messages = messages.copy()
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if final_status == "incomplete":
self.sequence_number += 1
final_response = self._snapshot_response("incomplete", output_messages)
yield OpenAIResponseObjectStreamResponseIncomplete(
response=final_response,
sequence_number=self.sequence_number,
)
else:
final_response = self._snapshot_response("completed", output_messages)
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
@ -211,6 +288,8 @@ class StreamingResponseOrchestrator:
for choice in current_response.choices:
next_turn_messages.append(choice.message)
logger.debug(f"Choice message content: {choice.message.content}")
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
@ -227,14 +306,183 @@ class StreamingResponseOrchestrator:
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
next_turn_messages.pop()
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
next_turn_messages.pop()
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
"""Accumulate usage from a streaming chunk into the response usage format."""
if not chunk.usage:
return
if self.accumulated_usage is None:
# Convert from chat completion format to response format
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else None
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else None
),
)
else:
# Accumulate across multiple inference calls
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
# Use latest non-null details
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else self.accumulated_usage.input_tokens_details
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else self.accumulated_usage.output_tokens_details
),
)
async def _handle_reasoning_content_chunk(
self,
reasoning_content: str,
reasoning_part_emitted: bool,
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first reasoning chunk
if not reasoning_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text="", # Will be filled incrementally via reasoning deltas
),
sequence_number=self.sequence_number,
)
# Emit reasoning_text.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDelta(
content_index=reasoning_content_index,
delta=reasoning_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _handle_refusal_content_chunk(
self,
refusal_content: str,
refusal_part_emitted: bool,
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first refusal chunk
if not refusal_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal="", # Will be filled incrementally via refusal deltas
),
sequence_number=self.sequence_number,
)
# Emit refusal.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDelta(
content_index=refusal_content_index,
delta=refusal_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _emit_reasoning_done_events(
self,
reasoning_text_accumulated: list[str],
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_reasoning_text = "".join(reasoning_text_accumulated)
# Emit reasoning_text.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDone(
content_index=reasoning_content_index,
text=final_reasoning_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for reasoning
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text=final_reasoning_text,
),
sequence_number=self.sequence_number,
)
async def _emit_refusal_done_events(
self,
refusal_text_accumulated: list[str],
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_refusal_text = "".join(refusal_text_accumulated)
# Emit refusal.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDone(
content_index=refusal_content_index,
refusal=final_refusal_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for refusal
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal=final_refusal_text,
),
sequence_number=self.sequence_number,
)
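# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# The helpers above always produce the same shape of event stream for a
# reasoning (or refusal) content part: one content_part.added, a delta per
# chunk, then a *.done followed by content_part.done. The generator below is a
# simplified stand-in that tracks only descriptive labels (not the real stream
# event classes or their exact `type` strings) to make that ordering explicit.
from collections.abc import Iterator


def reasoning_event_labels(chunks: list[str]) -> Iterator[str]:
    emitted_added = False
    for _ in chunks:
        if not emitted_added:
            emitted_added = True
            yield "content_part.added"
        yield "reasoning_text.delta"
    if emitted_added:
        yield "reasoning_text.done"
        yield "content_part.done"


if __name__ == "__main__":
    assert list(reasoning_event_labels(["step 1", "step 2"])) == [
        "content_part.added",
        "reasoning_text.delta",
        "reasoning_text.delta",
        "reasoning_text.done",
        "content_part.done",
    ]
# -----------------------------------------------------------------------------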
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
@ -253,11 +501,23 @@ class StreamingResponseOrchestrator:
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
reasoning_part_emitted = False
refusal_part_emitted = False
content_index = 0
reasoning_content_index = 1 # reasoning is a separate content part
refusal_content_index = 2 # refusal is a separate content part
message_output_index = len(output_messages)
reasoning_text_accumulated = []
refusal_text_accumulated = []
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
@ -266,8 +526,10 @@ class StreamingResponseOrchestrator:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
@ -275,10 +537,10 @@ class StreamingResponseOrchestrator:
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
@ -287,6 +549,32 @@ class StreamingResponseOrchestrator:
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Handle reasoning content if present (non-standard field for o1/o3 models)
if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
async for event in self._handle_reasoning_content_chunk(
reasoning_content=chunk_choice.delta.reasoning_content,
reasoning_part_emitted=reasoning_part_emitted,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
reasoning_part_emitted = True
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
# Handle refusal content if present
if chunk_choice.delta.refusal:
async for event in self._handle_refusal_content_chunk(
refusal_content=chunk_choice.delta.refusal,
refusal_part_emitted=refusal_part_emitted,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
refusal_part_emitted = True
refusal_text_accumulated.append(chunk_choice.delta.refusal)
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
@ -378,14 +666,36 @@ class StreamingResponseOrchestrator:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Emit reasoning done events if reasoning content was streamed
if reasoning_part_emitted:
async for event in self._emit_reasoning_done_events(
reasoning_text_accumulated=reasoning_text_accumulated,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Emit refusal done events if refusal content was streamed
if refusal_part_emitted:
async for event in self._emit_refusal_done_events(
refusal_text_accumulated=refusal_text_accumulated,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
@ -470,6 +780,8 @@ class StreamingResponseOrchestrator:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if result.citation_files:
self.citation_files.update(result.citation_files)
if tool_call_log:
output_messages.append(tool_call_log)
@ -518,7 +830,7 @@ class StreamingResponseOrchestrator:
sequence_number=self.sequence_number,
)
async def _process_tools(
async def _process_new_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
@ -573,7 +885,6 @@ class StreamingResponseOrchestrator:
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
@ -586,14 +897,22 @@ class StreamingResponseOrchestrator:
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
tool_defs = None
list_id = f"mcp_list_{uuid.uuid4()}"
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
id=list_id,
server_label=mcp_tool.server_label,
tools=[],
)
@ -627,39 +946,26 @@ class StreamingResponseOrchestrator:
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
async def _process_tools(
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
@ -694,7 +1000,6 @@ class StreamingResponseOrchestrator:
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
@ -702,3 +1007,60 @@ class StreamingResponseOrchestrator:
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _add_mcp_list_tools(
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _reuse_mcp_list_tools(
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
for t in original.tools:
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
# rebuild a ToolDefinition from the stored name, description and input_schema...
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
input_schema=t.input_schema,
)
# ...then convert it into an OpenAI chat-completions tool
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=original.server_label,
tools=original.tools,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
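# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Whether a listing is fetched fresh (_process_new_tools) or reused from the
# previous response (_reuse_mcp_list_tools), _add_mcp_list_tools emits the same
# three events, each bumping sequence_number by one:
#   1. output_item.added   (the mcp_list_tools message enters the output)
#   2. mcp_list_tools.completed
#   3. output_item.done
# The stand-in below models only the counter behaviour with descriptive labels;
# it does not construct the real stream event classes.
def add_mcp_list_tools_sequence(start: int) -> tuple[list[str], int]:
    sequence_number = start
    events = []
    for label in ("output_item.added", "mcp_list_tools.completed", "output_item.done"):
        sequence_number += 1
        events.append(f"{label}#{sequence_number}")
    return events, sequence_number


if __name__ == "__main__":
    events, final_seq = add_mcp_list_tools_sequence(start=7)
    assert events == ["output_item.added#8", "mcp_list_tools.completed#9", "output_item.done#10"]
    assert final_seq == 10
# -----------------------------------------------------------------------------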

View file

@ -11,6 +11,9 @@ from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
@ -35,6 +38,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ToolExecutionResult
@ -94,7 +98,10 @@ class ToolExecutor:
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
@ -129,8 +136,6 @@ class ToolExecutor:
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
@ -138,27 +143,58 @@ class ToolExecutor:
)
)
unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
citation_instruction = ""
if unique_files:
citation_instruction = (
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
)
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
# Fall back to attributes for results produced by older versions
citation_files = {}
for result in search_results:
file_id = result.file_id
if not file_id and result.attributes:
file_id = result.attributes.get("document_id")
filename = result.filename
if not filename and result.attributes:
filename = result.attributes.get("filename")
if not filename:
filename = "unknown"
citation_files[file_id] = filename
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"citation_files": citation_files,
},
)
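# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# citation_files maps file_id -> filename so the response layer can turn
# "<|file-...|>" markers back into file citations. A minimal stand-in of the
# fallback logic above (file_id / filename may live in `attributes` for results
# produced by older versions); the dict-shaped "results" here are hypothetical.
def build_citation_files(search_results: list[dict]) -> dict[str, str]:
    citation_files = {}
    for result in search_results:
        file_id = result.get("file_id") or (result.get("attributes") or {}).get("document_id")
        filename = result.get("filename") or (result.get("attributes") or {}).get("filename") or "unknown"
        citation_files[file_id] = filename
    return citation_files


if __name__ == "__main__":
    results = [
        {"file_id": "file-abc", "filename": "notes.md"},
        {"file_id": "", "attributes": {"document_id": "file-def"}},  # old-style result
    ]
    assert build_citation_files(results) == {"file-abc": "notes.md", "file-def": "unknown"}
# -----------------------------------------------------------------------------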
@ -188,7 +224,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
@ -203,6 +245,16 @@ class ToolExecutor:
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
@ -219,12 +271,18 @@ class ToolExecutor:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
@ -234,15 +292,20 @@ class ToolExecutor:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
async with tracing.span("knowledge_search", {}):
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"tool_name": function_name,
}
async with tracing.span("invoke_tool", attributes):
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
@ -278,7 +341,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
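# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Tool invocations above are now wrapped in tracing spans, e.g.
#     async with tracing.span("invoke_mcp_tool", {"server_label": ..., "tool_name": ...}):
#         result = await invoke_mcp_tool(...)
# The stand-in below mimics that shape with a dummy async context manager so the
# pattern is runnable on its own; `span` and `fake_invoke` here are hypothetical
# stand-ins, not the real llama_stack.providers.utils.telemetry API.
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def span(name: str, attributes: dict):
    print(f"span start: {name} {attributes}")
    try:
        yield
    finally:
        print(f"span end: {name}")


async def fake_invoke(tool_name: str) -> str:
    return f"{tool_name} ok"


async def main() -> None:
    async with span("invoke_tool", {"tool_name": "knowledge_search"}):
        result = await fake_invoke("knowledge_search")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
# -----------------------------------------------------------------------------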

View file

@ -12,10 +12,18 @@ from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseTool,
OpenAIResponseToolMCP,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
@ -27,6 +35,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
citation_files: dict[str, str] | None = None
@dataclass
@ -54,6 +63,86 @@ class ChatCompletionResult:
return bool(self.tool_calls)
class ToolContext(BaseModel):
"""Holds information about tools from this and (if relevant)
previous response in order to facilitate reuse of previous
listings where appropriate."""
# tools argument passed into current request:
current_tools: list[OpenAIResponseInputTool]
# reconstructed map of tool -> mcp server from previous response:
previous_tools: dict[str, OpenAIResponseInputToolMCP]
# reusable mcp-list-tools objects from previous response:
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
# tool arguments from current request that still need to be processed:
tools_to_process: list[OpenAIResponseInputTool]
def __init__(
self,
current_tools: list[OpenAIResponseInputTool] | None,
):
super().__init__(
current_tools=current_tools or [],
previous_tools={},
previous_tool_listings=[],
tools_to_process=current_tools or [],
)
def recover_tools_from_previous_response(
self,
previous_response: OpenAIResponseObject,
):
"""Determine which mcp_list_tools objects from previous response we can reuse."""
if self.current_tools and previous_response.tools:
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
for tool in previous_response.tools:
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
else:
tools_to_process.append(tool)
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
self.previous_tool_listings = [
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
return []
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
if isinstance(tool, OpenAIResponseInputToolWebSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFileSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFunction):
return tool
if isinstance(tool, OpenAIResponseInputToolMCP):
return OpenAIResponseToolMCP(
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
return [convert_tool(tool) for tool in self.current_tools]
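# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# The reuse rule above boils down to: an MCP tool definition from the current
# request can reuse the previous response's mcp_list_tools output only if the
# same server_label appeared before *and* allowed_tools is unchanged; anything
# else (including non-MCP tools, omitted here) goes back through processing.
# A simplified stand-in using plain dicts (hypothetical shapes, not the real
# pydantic models):
def split_reusable(current: list[dict], previous: dict[str, dict]) -> tuple[list[str], list[dict]]:
    reusable_labels, to_process = [], []
    for tool in current:
        prev = previous.get(tool["server_label"])
        if prev is not None and prev.get("allowed_tools") == tool.get("allowed_tools"):
            reusable_labels.append(tool["server_label"])
        else:
            to_process.append(tool)
    return reusable_labels, to_process


if __name__ == "__main__":
    current = [
        {"server_label": "docs", "allowed_tools": ["search"]},
        {"server_label": "tickets", "allowed_tools": ["create"]},
    ]
    previous = {"docs": {"allowed_tools": ["search"]}, "tickets": {"allowed_tools": None}}
    assert split_reusable(current, previous) == (
        ["docs"],
        [{"server_label": "tickets", "allowed_tools": ["create"]}],
    )
# -----------------------------------------------------------------------------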
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
@ -61,6 +150,7 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
@ -71,6 +161,7 @@ class ChatCompletionContext(BaseModel):
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
@ -79,6 +170,7 @@ class ChatCompletionContext(BaseModel):
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
@ -95,3 +187,8 @@ class ChatCompletionContext(BaseModel):
if request.name == tool_name and request.arguments == arguments:
return request
return None
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.tool_context:
return []
return self.tool_context.available_tools()

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -45,7 +47,12 @@ from llama_stack.apis.inference import (
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice,
citation_files: dict[str, str] | None = None,
*,
message_id: str | None = None,
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
@ -57,9 +64,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
@ -97,9 +106,13 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@ -163,16 +176,53 @@ async def convert_response_input_to_chat_messages(
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If unpaired outputs remain, raise an error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
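# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Pairing rule for function_call_output items: an output whose call_id matches
# a tool_call id found in previous_messages is accepted; anything left unpaired
# raises. A simplified stand-in operating on plain dicts (hypothetical shapes):
def pair_outputs(outputs: dict[str, str], previous_call_ids: set[str]) -> list[str]:
    accepted = []
    for call_id in list(outputs):
        if call_id in previous_call_ids:
            accepted.append(outputs.pop(call_id))
    if outputs:
        raise ValueError(f"function_call_output(s) {sorted(outputs)} have no corresponding function_call")
    return accepted


if __name__ == "__main__":
    assert pair_outputs({"call_1": "42"}, previous_call_ids={"call_1"}) == ["42"]
    try:
        pair_outputs({"call_2": "x"}, previous_call_ids=set())
    except ValueError:
        pass  # expected: no matching function_call for call_2
# -----------------------------------------------------------------------------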
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
@ -200,6 +250,53 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)
def _extract_citations_from_text(
text: str, citation_files: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
"""Extract citation markers from text and create annotations
Args:
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
citation_files: Dictionary mapping file_id to filename
Returns:
Tuple of (annotations_list, clean_text_without_markers)
"""
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
annotations = []
parts = []
total_len = 0
last_end = 0
for m in file_id_regex.finditer(text):
# segment before the marker
prefix = text[last_end : m.start()]
# drop one space if it exists (since marker is at sentence end)
if prefix.endswith(" "):
prefix = prefix[:-1]
parts.append(prefix)
total_len += len(prefix)
fid = m.group(1)
if fid in citation_files:
annotations.append(
OpenAIResponseAnnotationFileCitation(
file_id=fid,
filename=citation_files[fid],
index=total_len, # index points to punctuation
)
)
last_end = m.end()
parts.append(text[last_end:])
cleaned_text = "".join(parts)
return annotations, cleaned_text
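# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Expected behaviour of the extraction above, shown on a standalone example
# with a hypothetical file id. "Water boils at 100 C <|file-abc|>." becomes
# "Water boils at 100 C." with one annotation whose index points at the ".".
# This stand-in skips the citation_files lookup and returns plain tuples
# instead of OpenAIResponseAnnotationFileCitation objects.
import re

_MARKER = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")


def strip_markers(text: str) -> tuple[list[tuple[str, int]], str]:
    annotations, parts, total_len, last_end = [], [], 0, 0
    for m in _MARKER.finditer(text):
        prefix = text[last_end : m.start()]
        if prefix.endswith(" "):
            prefix = prefix[:-1]  # drop the space before the marker
        parts.append(prefix)
        total_len += len(prefix)
        annotations.append((m.group("file_id"), total_len))
        last_end = m.end()
    parts.append(text[last_end:])
    return annotations, "".join(parts)


if __name__ == "__main__":
    annotations, clean = strip_markers("Water boils at 100 C <|file-abc|>.")
    assert clean == "Water boils at 100 C."
    assert annotations == [("file-abc", 20)]  # index of the trailing "."
# -----------------------------------------------------------------------------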
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],

View file

@ -22,7 +22,10 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
@ -178,9 +181,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -425,18 +428,23 @@ class ReferenceBatchesImpl(Batches):
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params = [
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out that messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
@ -601,7 +609,8 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
chat_response = await self.inference_api.openai_chat_completion(chat_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@ -614,8 +623,9 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
elif request.url == "/v1/completions":
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
completion_response = await self.inference_api.openai_completion(completion_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
@ -630,6 +640,22 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {

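# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# With /v1/embeddings accepted as a batch endpoint, a line of the batch input
# file can look like the dict below (the model name and custom_id are
# hypothetical): "model" must be a string and "input" a string or list of
# strings, matching the validation above.
import json

example_line = {
    "custom_id": "req-1",
    "method": "POST",
    "url": "/v1/embeddings",
    "body": {"model": "example-embedding-model", "input": ["first text", "second text"]},
}

if __name__ == "__main__":
    print(json.dumps(example_line))
# -----------------------------------------------------------------------------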
View file

@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.openai_completion(
params = OpenAICompletionRequestWithExtraBody(
model=candidate.model,
prompt=input_content,
**sampling_params,
)
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=candidate.model,
messages=messages,
**sampling_params,
)
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -22,6 +22,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized")
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)

View file

@ -18,7 +18,7 @@ def model_checkpoint_dir(model_id) -> str:
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)

View file

@ -6,16 +6,16 @@
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(self, *args, **kwargs):
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -104,9 +104,10 @@ class LoraFinetuningSingleDevice:
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
)
return str(checkpoint_dir)

View file

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from codeshield.cs import CodeShieldScanResult
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)

View file

@ -10,7 +10,12 @@ from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -159,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@ -169,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
if len(messages) > 0 and messages[0].role != "user":
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
@ -202,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
messages = [OpenAIUserMessageParam(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
@ -271,7 +276,7 @@ class LlamaGuardShield:
return final_categories
def validate_messages(self, messages: list[Message]) -> None:
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
@ -282,7 +287,7 @@ class LlamaGuardShield:
return messages
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@ -290,20 +295,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -326,7 +332,7 @@ class LlamaGuardShield:
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
conversation.append(OpenAIUserMessageParam(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
@ -335,9 +341,9 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
@ -370,18 +376,20 @@ class LlamaGuardShield:
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add image-based support for OpenAI moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)
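# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Shields now consume OpenAI-style message params rather than the legacy
# Message/UserMessage types. A minimal call shape; the shield id is
# hypothetical and `safety_api` is assumed to be an initialized Safety
# implementation.
from llama_stack.apis.inference import OpenAIUserMessageParam


async def check_user_input(safety_api, text: str):
    return await safety_api.run_shield(
        shield_id="llama-guard",
        messages=[OpenAIUserMessageParam(content=text)],
        params={},
    )
# -----------------------------------------------------------------------------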

View file

@ -9,7 +9,7 @@ from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@ -93,7 +91,7 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestWithExtraBody
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=fn_def.params.judge_model,
messages=[
{
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
}
],
)
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

View file

@ -8,8 +8,6 @@ import asyncio
import base64
import io
import mimetypes
import secrets
import string
from typing import Any
import httpx
@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return ToolInvocationResult(
content=result.content or [],
metadata=result.metadata,
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)

View file

@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
@ -227,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def health(self) -> HealthResponse:
"""

View file

@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
"""
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
@ -436,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]

View file

@ -32,9 +32,12 @@ def available_providers() -> list[ProviderSpec]:
Api.inference,
Api.safety,
Api.vector_io,
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,
],
optional_api_dependencies=[
Api.telemetry,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),

View file

@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
pip_packages=["anthropic"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
@ -233,9 +228,7 @@ Available Models:
api=Api.inference,
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@ -245,7 +238,7 @@ Available Models:
api=Api.inference,
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@ -255,9 +248,7 @@ Available Models:
api=Api.inference,
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@ -277,7 +268,7 @@ Available Models:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
@ -287,7 +278,7 @@ Available Models:
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",

View file

@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"chardet",
"pypdf",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",

View file

@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
)
# Common dependencies for all vector IO providers that support document processing
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
@ -82,7 +85,7 @@ more details about Faiss in general.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=["chromadb"],
pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client>=4.16.5"],
pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file
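
The registry hunks above factor the shared document-processing packages into a single `DEFAULT_VECTOR_IO_DEPS` constant and append it per provider, and the rag-runtime registry imports the same constant instead of re-listing the packages. A minimal sketch of that composition, using simplified provider entries rather than the real `InlineProviderSpec` fields:

```python
# Shared document-processing dependencies for vector IO providers (from the diff).
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]

# Illustrative provider entries; only pip_packages matters for this pattern.
providers = [
    {"provider_type": "inline::faiss", "pip_packages": ["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS},
    {"provider_type": "remote::pgvector", "pip_packages": ["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS},
]

# Another registry can reuse the constant, as the rag-runtime hunk does above.
rag_runtime_deps = DEFAULT_VECTOR_IO_DEPS + ["tqdm", "numpy", "scikit-learn"]
```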

@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"

View file
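
The S3 files hunk routes file-ID creation through `generate_object_id` with a fallback factory. The helper's implementation is not part of this diff; the following is only a sketch of one way a helper with that call shape could behave, with the override registry being a purely hypothetical detail:

```python
import uuid
from collections.abc import Callable

# Hypothetical per-kind overrides (e.g. injected by tests or an ID service).
_ID_OVERRIDES: dict[str, Callable[[], str]] = {}


def generate_object_id(kind: str, default_factory: Callable[[], str]) -> str:
    """Return an ID for `kind`, preferring a registered override over the default factory."""
    override = _ID_OVERRIDES.get(kind)
    return override() if override is not None else default_factory()


# Matches the call shape in the hunk above.
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
```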

@ -10,6 +10,6 @@ from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
@ -23,22 +29,8 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

View file
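
The Anthropic hunk replaces the LiteLLM-based `__init__` with a declarative `OpenAIMixin` adapter: the config and the provider-data key field become class attributes, and only the base URL and model listing stay provider-specific. A compressed sketch of that shape for a hypothetical provider, with a stub standing in for the mixin's OpenAI-compatible client plumbing (not shown in this diff):

```python
from collections.abc import Iterable

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    api_key: str | None = None


class OpenAIMixinSketch(BaseModel):
    """Stand-in for OpenAIMixin; the real mixin supplies the client plumbing."""

    config: ExampleConfig
    provider_data_api_key_field: str = "example_api_key"

    def get_api_key(self) -> str | None:
        # The real mixin also consults per-request provider data; config-only here.
        return self.config.api_key


class ExampleInferenceAdapter(OpenAIMixinSketch):
    # Provider-specific pieces stay small: a base URL and, optionally, model listing.
    def get_base_url(self) -> str:
        return "https://api.example.com/v1"

    async def list_provider_model_ids(self) -> Iterable[str]:
        return ["example-model-small", "example-model-large"]


adapter = ExampleInferenceAdapter(config=ExampleConfig(api_key="sk-..."))
```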

@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,31 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_base_url(self) -> str:
"""
@ -37,26 +23,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params

View file
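
The Azure hunk drops the hand-rolled `_get_params` in favour of the same declarative pattern; per-request credentials are expected to flow through the field named by `provider_data_api_key_field`. The mixin's resolution logic is not shown here, so the following is only a sketch of the fallback order the change implies (request-scoped key first, then the configured secret), with a locally defined config model rather than the real one:

```python
from pydantic import BaseModel, SecretStr


class AzureLikeConfig(BaseModel):
    api_key: SecretStr
    api_base: str


def resolve_api_key(config: AzureLikeConfig, provider_data: dict | None, field: str = "azure_api_key") -> str:
    """Prefer a per-request key from provider data; fall back to the configured secret."""
    if provider_data and provider_data.get(field):
        return provider_data[field]
    return config.api_key.get_secret_value()


cfg = AzureLikeConfig(api_key=SecretStr("cfg-key"), api_base="https://example.openai.azure.com")
assert resolve_api_key(cfg, {"azure_api_key": "request-key"}) == "request-key"
assert resolve_api_key(cfg, None) == "cfg-key"
```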

@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)

View file

@ -6,21 +6,21 @@
import json
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -125,66 +125,18 @@ class BedrockInferenceAdapter(
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file
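
The Bedrock hunk collapses the long keyword-argument signatures into single request objects (`OpenAIChatCompletionRequestWithExtraBody` and friends). The exact fields of those models are not reproduced in this diff, so the sketch below uses a hypothetical request model purely to show the call shape the change implies:

```python
from collections.abc import AsyncIterator
from typing import Any

from pydantic import BaseModel


class ChatCompletionRequestSketch(BaseModel):
    """Hypothetical stand-in for OpenAIChatCompletionRequestWithExtraBody."""

    model: str
    messages: list[dict[str, Any]]
    temperature: float | None = None
    stream: bool | None = None


class BedrockLikeAdapter:
    async def openai_chat_completion(
        self,
        params: ChatCompletionRequestSketch,
    ) -> dict | AsyncIterator[dict]:
        # The real adapter raises here as well; shown only for the call shape.
        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")


# Callers now build one request object instead of passing many keyword arguments.
request = ChatCompletionRequestSketch(
    model="example-model",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,
)
```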

@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()

View file

@ -6,77 +6,23 @@
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Inference,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(
OpenAIMixin,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
self.config = config
# TODO: make this use provider data, etc. like other providers
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

Some files were not shown because too many files have changed in this diff.