mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 12:38:04 +00:00
Merge branch 'main' into make-kvstore-optional
This commit is contained in:
commit
f62e6cb063
554 changed files with 63962 additions and 4870 deletions
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.library_client import ( # noqa: F401
|
||||
from llama_stack.core.library_client import ( # noqa: F401
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -152,7 +152,17 @@ Step = Annotated[
|
|||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System."""
|
||||
"""A single turn in an interaction with an Agentic System.
|
||||
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param input_messages: List of messages that initiated this turn
|
||||
:param steps: Ordered list of processing steps executed during this turn
|
||||
:param output_message: The model's generated response containing content and metadata
|
||||
:param output_attachments: (Optional) Files or media attached to the agent's response
|
||||
:param started_at: Timestamp when the turn began
|
||||
:param completed_at: (Optional) Timestamp when the turn finished, if completed
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
|
|
@ -167,7 +177,13 @@ class Turn(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System."""
|
||||
"""A single session of an interaction with an Agentic System.
|
||||
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param session_name: Human-readable name for the session
|
||||
:param turns: List of all turns that have occurred in this session
|
||||
:param started_at: Timestamp when the session was created
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
|
|
@ -232,6 +248,13 @@ class AgentConfig(AgentConfigCommon):
|
|||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
"""An agent instance with configuration and metadata.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param agent_config: Configuration settings for the agent
|
||||
:param created_at: Timestamp when the agent was created
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
|
@ -253,6 +276,14 @@ class AgentTurnResponseEventType(StrEnum):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
"""Payload for step start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param metadata: (Optional) Additional metadata for the step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
|
@ -261,6 +292,14 @@ class AgentTurnResponseStepStartPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
"""Payload for step completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param step_details: Complete details of the executed step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
|
@ -269,6 +308,14 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
"""Payload for step progress events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param delta: Incremental content changes during step execution
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
|
||||
|
|
@ -280,18 +327,36 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
"""Payload for turn start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
"""Payload for turn completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Complete turn data including all steps and results
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
"""Payload for turn awaiting input events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Turn data when waiting for external tool responses
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
|
||||
turn: Turn
|
||||
|
||||
|
|
@ -310,21 +375,47 @@ register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPaylo
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""An event in an agent turn response stream.
|
||||
|
||||
:param payload: Event-specific payload containing event data
|
||||
"""
|
||||
|
||||
payload: AgentTurnResponseEventPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent.
|
||||
|
||||
:param agent_id: Unique identifier for the created agent
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentSessionCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent session.
|
||||
|
||||
:param session_id: Unique identifier for the created session
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
"""Request to create a new turn for an agent.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param messages: List of messages to start the turn with
|
||||
:param documents: (Optional) List of documents to provide to the agent
|
||||
:param toolgroups: (Optional) List of tool groups to make available for this turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param tool_config: (Optional) Tool configuration to override agent defaults
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
|
|
@ -342,6 +433,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
"""Request to resume an agent turn with tool responses.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param tool_responses: List of tool responses to submit to continue the turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
|
|
@ -351,13 +451,21 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
"""streamed agent turn completion response."""
|
||||
"""Streamed agent turn completion response.
|
||||
|
||||
:param event: Individual event in the agent turn response stream
|
||||
"""
|
||||
|
||||
event: AgentTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentStepResponse(BaseModel):
|
||||
"""Response containing details of a specific agent step.
|
||||
|
||||
:param step: The complete step data and execution details
|
||||
"""
|
||||
|
||||
step: Step
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,18 +18,37 @@ from llama_stack.schema_utils import json_schema_type, register_schema
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseError(BaseModel):
|
||||
"""Error details for failed OpenAI response requests.
|
||||
|
||||
:param code: Error code identifying the type of failure
|
||||
:param message: Human-readable error message describing the failure
|
||||
"""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
"""Text content for input messages in OpenAI response format.
|
||||
|
||||
:param text: The text content of the input message
|
||||
:param type: Content type identifier, always "input_text"
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
"""Image content for input messages in OpenAI response format.
|
||||
|
||||
:param detail: Level of detail for image processing, can be "low", "high", or "auto"
|
||||
:param type: Content type identifier, always "input_image"
|
||||
:param image_url: (Optional) URL of the image content
|
||||
"""
|
||||
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
|
|
@ -46,6 +65,14 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseAnnotationFileCitation(BaseModel):
|
||||
"""File citation annotation for referencing specific files in response content.
|
||||
|
||||
:param type: Annotation type identifier, always "file_citation"
|
||||
:param file_id: Unique identifier of the referenced file
|
||||
:param filename: Name of the referenced file
|
||||
:param index: Position index of the citation within the content
|
||||
"""
|
||||
|
||||
type: Literal["file_citation"] = "file_citation"
|
||||
file_id: str
|
||||
filename: str
|
||||
|
|
@ -54,6 +81,15 @@ class OpenAIResponseAnnotationFileCitation(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseAnnotationCitation(BaseModel):
|
||||
"""URL citation annotation for referencing external web resources.
|
||||
|
||||
:param type: Annotation type identifier, always "url_citation"
|
||||
:param end_index: End position of the citation span in the content
|
||||
:param start_index: Start position of the citation span in the content
|
||||
:param title: Title of the referenced web resource
|
||||
:param url: URL of the referenced web resource
|
||||
"""
|
||||
|
||||
type: Literal["url_citation"] = "url_citation"
|
||||
end_index: int
|
||||
start_index: int
|
||||
|
|
@ -122,6 +158,13 @@ class OpenAIResponseMessage(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||
"""Web search tool call output message for OpenAI responses.
|
||||
|
||||
:param id: Unique identifier for this tool call
|
||||
:param status: Current status of the web search operation
|
||||
:param type: Tool call type identifier, always "web_search_call"
|
||||
"""
|
||||
|
||||
id: str
|
||||
status: str
|
||||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
|
@ -129,6 +172,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
|
||||
"""File search tool call output message for OpenAI responses.
|
||||
|
||||
:param id: Unique identifier for this tool call
|
||||
:param queries: List of search queries executed
|
||||
:param status: Current status of the file search operation
|
||||
:param type: Tool call type identifier, always "file_search_call"
|
||||
:param results: (Optional) Search results returned by the file search operation
|
||||
"""
|
||||
|
||||
id: str
|
||||
queries: list[str]
|
||||
status: str
|
||||
|
|
@ -138,6 +190,16 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||
"""Function tool call output message for OpenAI responses.
|
||||
|
||||
:param call_id: Unique identifier for the function call
|
||||
:param name: Name of the function being called
|
||||
:param arguments: JSON string containing the function arguments
|
||||
:param type: Tool call type identifier, always "function_call"
|
||||
:param id: (Optional) Additional identifier for the tool call
|
||||
:param status: (Optional) Current status of the function call execution
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
name: str
|
||||
arguments: str
|
||||
|
|
@ -148,6 +210,17 @@ class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
||||
"""Model Context Protocol (MCP) call output message for OpenAI responses.
|
||||
|
||||
:param id: Unique identifier for this MCP call
|
||||
:param type: Tool call type identifier, always "mcp_call"
|
||||
:param arguments: JSON string containing the MCP call arguments
|
||||
:param name: Name of the MCP method being called
|
||||
:param server_label: Label identifying the MCP server handling the call
|
||||
:param error: (Optional) Error message if the MCP call failed
|
||||
:param output: (Optional) Output result from the successful MCP call
|
||||
"""
|
||||
|
||||
id: str
|
||||
type: Literal["mcp_call"] = "mcp_call"
|
||||
arguments: str
|
||||
|
|
@ -158,6 +231,13 @@ class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
|||
|
||||
|
||||
class MCPListToolsTool(BaseModel):
|
||||
"""Tool definition returned by MCP list tools operation.
|
||||
|
||||
:param input_schema: JSON schema defining the tool's input parameters
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Description of what the tool does
|
||||
"""
|
||||
|
||||
input_schema: dict[str, Any]
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
|
@ -165,6 +245,14 @@ class MCPListToolsTool(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
||||
"""MCP list tools output message containing available tools from an MCP server.
|
||||
|
||||
:param id: Unique identifier for this MCP list tools operation
|
||||
:param type: Tool call type identifier, always "mcp_list_tools"
|
||||
:param server_label: Label identifying the MCP server providing the tools
|
||||
:param tools: List of available tools provided by the MCP server
|
||||
"""
|
||||
|
||||
id: str
|
||||
type: Literal["mcp_list_tools"] = "mcp_list_tools"
|
||||
server_label: str
|
||||
|
|
@ -206,11 +294,34 @@ class OpenAIResponseTextFormat(TypedDict, total=False):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseText(BaseModel):
|
||||
"""Text response configuration for OpenAI responses.
|
||||
|
||||
:param format: (Optional) Text format configuration specifying output format requirements
|
||||
"""
|
||||
|
||||
format: OpenAIResponseTextFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObject(BaseModel):
|
||||
"""Complete OpenAI response object containing generation results and metadata.
|
||||
|
||||
:param created_at: Unix timestamp when the response was created
|
||||
:param error: (Optional) Error details if the response generation failed
|
||||
:param id: Unique identifier for this response
|
||||
:param model: Model identifier used for generation
|
||||
:param object: Object type identifier, always "response"
|
||||
:param output: List of generated output items (messages, tool calls, etc.)
|
||||
:param parallel_tool_calls: Whether tool calls can be executed in parallel
|
||||
:param previous_response_id: (Optional) ID of the previous response in a conversation
|
||||
:param status: Current status of the response generation
|
||||
:param temperature: (Optional) Sampling temperature used for generation
|
||||
:param text: Text formatting configuration for the response
|
||||
:param top_p: (Optional) Nucleus sampling parameter used for generation
|
||||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param user: (Optional) User identifier associated with the request
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
error: OpenAIResponseError | None = None
|
||||
id: str
|
||||
|
|
@ -231,6 +342,13 @@ class OpenAIResponseObject(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIDeleteResponseObject(BaseModel):
|
||||
"""Response object confirming deletion of an OpenAI response.
|
||||
|
||||
:param id: Unique identifier of the deleted response
|
||||
:param object: Object type identifier, always "response"
|
||||
:param deleted: Deletion confirmation flag, always True
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["response"] = "response"
|
||||
deleted: bool = True
|
||||
|
|
@ -238,18 +356,39 @@ class OpenAIDeleteResponseObject(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||
"""Streaming event indicating a new response has been created.
|
||||
|
||||
:param response: The newly created response object
|
||||
:param type: Event type identifier, always "response.created"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
"""Streaming event indicating a response has been completed.
|
||||
|
||||
:param response: The completed response object
|
||||
:param type: Event type identifier, always "response.completed"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
|
||||
"""Streaming event for when a new output item is added to the response.
|
||||
|
||||
:param response_id: Unique identifier of the response containing this output
|
||||
:param item: The output item that was added (message, tool call, etc.)
|
||||
:param output_index: Index position of this item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_item.added"
|
||||
"""
|
||||
|
||||
response_id: str
|
||||
item: OpenAIResponseOutput
|
||||
output_index: int
|
||||
|
|
@ -259,6 +398,15 @@ class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel):
|
||||
"""Streaming event for when an output item is completed.
|
||||
|
||||
:param response_id: Unique identifier of the response containing this output
|
||||
:param item: The completed output item (message, tool call, etc.)
|
||||
:param output_index: Index position of this item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_item.done"
|
||||
"""
|
||||
|
||||
response_id: str
|
||||
item: OpenAIResponseOutput
|
||||
output_index: int
|
||||
|
|
@ -268,6 +416,16 @@ class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
||||
"""Streaming event for incremental text content updates.
|
||||
|
||||
:param content_index: Index position within the text content
|
||||
:param delta: Incremental text content being added
|
||||
:param item_id: Unique identifier of the output item being updated
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_text.delta"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
|
|
@ -278,6 +436,16 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel):
|
||||
"""Streaming event for when text output is completed.
|
||||
|
||||
:param content_index: Index position within the text content
|
||||
:param text: Final complete text content of the output item
|
||||
:param item_id: Unique identifier of the completed output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_text.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
text: str # final text of the output item
|
||||
item_id: str
|
||||
|
|
@ -288,6 +456,15 @@ class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
|
||||
"""Streaming event for incremental function call argument updates.
|
||||
|
||||
:param delta: Incremental function call arguments being added
|
||||
:param item_id: Unique identifier of the function call being updated
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.function_call_arguments.delta"
|
||||
"""
|
||||
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
|
|
@ -297,6 +474,15 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
|
||||
"""Streaming event for when function call arguments are completed.
|
||||
|
||||
:param arguments: Final complete arguments JSON string for the function call
|
||||
:param item_id: Unique identifier of the completed function call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.function_call_arguments.done"
|
||||
"""
|
||||
|
||||
arguments: str # final arguments of the function call
|
||||
item_id: str
|
||||
output_index: int
|
||||
|
|
@ -306,6 +492,14 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
|
||||
"""Streaming event for web search calls in progress.
|
||||
|
||||
:param item_id: Unique identifier of the web search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.web_search_call.in_progress"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
|
|
@ -322,6 +516,14 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
|
||||
"""Streaming event for completed web search calls.
|
||||
|
||||
:param item_id: Unique identifier of the completed web search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.web_search_call.completed"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
|
|
@ -366,6 +568,14 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel):
|
||||
"""Streaming event for MCP calls in progress.
|
||||
|
||||
:param item_id: Unique identifier of the MCP call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.mcp_call.in_progress"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
|
|
@ -374,12 +584,24 @@ class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallFailed(BaseModel):
|
||||
"""Streaming event for failed MCP calls.
|
||||
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.mcp_call.failed"
|
||||
"""
|
||||
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.failed"] = "response.mcp_call.failed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
|
||||
"""Streaming event for completed MCP calls.
|
||||
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.mcp_call.completed"
|
||||
"""
|
||||
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
|
||||
|
||||
|
|
@ -442,6 +664,12 @@ WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_20
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
"""Web search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Web search tool type variant to use
|
||||
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
|
||||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
|
|
@ -453,6 +681,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
"""Function tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "function"
|
||||
:param name: Name of the function that can be called
|
||||
:param description: (Optional) Description of what the function does
|
||||
:param parameters: (Optional) JSON schema defining the function's parameters
|
||||
:param strict: (Optional) Whether to enforce strict parameter validation
|
||||
"""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
|
@ -462,6 +699,15 @@ class OpenAIResponseInputToolFunction(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
"""File search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "file_search"
|
||||
:param vector_store_ids: List of vector store identifiers to search within
|
||||
:param filters: (Optional) Additional filters to apply to the search
|
||||
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
|
||||
:param ranking_options: (Optional) Options for ranking and scoring search results
|
||||
"""
|
||||
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_ids: list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
|
|
@ -470,16 +716,37 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
|
|||
|
||||
|
||||
class ApprovalFilter(BaseModel):
|
||||
"""Filter configuration for MCP tool approval requirements.
|
||||
|
||||
:param always: (Optional) List of tool names that always require approval
|
||||
:param never: (Optional) List of tool names that never require approval
|
||||
"""
|
||||
|
||||
always: list[str] | None = None
|
||||
never: list[str] | None = None
|
||||
|
||||
|
||||
class AllowedToolsFilter(BaseModel):
|
||||
"""Filter configuration for restricting which MCP tools can be used.
|
||||
|
||||
:param tool_names: (Optional) List of specific tool names that are allowed
|
||||
"""
|
||||
|
||||
tool_names: list[str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param server_url: URL endpoint of the MCP server
|
||||
:param headers: (Optional) HTTP headers to include when connecting to the server
|
||||
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
server_url: str
|
||||
|
|
@ -500,17 +767,37 @@ register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
|||
|
||||
|
||||
class ListOpenAIResponseInputItem(BaseModel):
|
||||
"""List container for OpenAI response input items.
|
||||
|
||||
:param data: List of input items
|
||||
:param object: Object type identifier, always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIResponseInput]
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
||||
"""OpenAI response object extended with input context information.
|
||||
|
||||
:param input: List of input items that led to this response
|
||||
"""
|
||||
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
"""Paginated list of OpenAI response objects with navigation metadata.
|
||||
|
||||
:param data: List of response objects with their input context
|
||||
:param has_more: Whether there are more results available beyond this page
|
||||
:param first_id: Identifier of the first item in this page
|
||||
:param last_id: Identifier of the last item in this page
|
||||
:param object: Object type identifier, always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIResponseObjectWithInput]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ class CommonBenchmarkFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Benchmark(CommonBenchmarkFields, Resource):
|
||||
"""A benchmark resource for evaluating model performance.
|
||||
|
||||
:param dataset_id: Identifier of the dataset to use for the benchmark evaluation
|
||||
:param scoring_functions: List of scoring function identifiers to apply during evaluation
|
||||
:param metadata: Metadata for this evaluation task
|
||||
:param type: The resource type, always benchmark
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -15,6 +15,11 @@ from llama_stack.schema_utils import json_schema_type, register_schema
|
|||
|
||||
@json_schema_type
|
||||
class URL(BaseModel):
|
||||
"""A URL reference to external content.
|
||||
|
||||
:param uri: The URL string pointing to the resource
|
||||
"""
|
||||
|
||||
uri: str
|
||||
|
||||
|
||||
|
|
@ -76,17 +81,36 @@ register_schema(InterleavedContent, name="InterleavedContent")
|
|||
|
||||
@json_schema_type
|
||||
class TextDelta(BaseModel):
|
||||
"""A text content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "text"
|
||||
:param text: The incremental text content
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageDelta(BaseModel):
|
||||
"""An image content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "image"
|
||||
:param image: The incremental image data as bytes
|
||||
"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: bytes
|
||||
|
||||
|
||||
class ToolCallParseStatus(Enum):
|
||||
"""Status of tool call parsing during streaming.
|
||||
:cvar started: Tool call parsing has begun
|
||||
:cvar in_progress: Tool call parsing is ongoing
|
||||
:cvar failed: Tool call parsing failed
|
||||
:cvar succeeded: Tool call parsing completed successfully
|
||||
"""
|
||||
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
|
@ -95,6 +119,13 @@ class ToolCallParseStatus(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class ToolCallDelta(BaseModel):
|
||||
"""A tool call content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "tool_call"
|
||||
:param tool_call: Either an in-progress tool call string or the final parsed tool call
|
||||
:param parse_status: Current parsing status of the tool call
|
||||
"""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
|
||||
# you either send an in-progress tool call so the client can stream a long
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Custom Llama Stack Exception classes should follow the following schema
|
||||
# 1. All classes should inherit from an existing Built-In Exception class: https://docs.python.org/3/library/exceptions.html
|
||||
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
|
||||
# 3. All classes should propagate the inherited __init__ function otherwise via 'super().__init__(message)'
|
||||
|
||||
|
||||
class UnsupportedModelError(ValueError):
|
||||
"""raised when model is not present in the list of supported models"""
|
||||
|
|
@ -11,3 +16,45 @@ class UnsupportedModelError(ValueError):
|
|||
def __init__(self, model_name: str, supported_models_list: list[str]):
|
||||
message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced model"""
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class VectorStoreNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced vector store"""
|
||||
|
||||
def __init__(self, vector_store_name: str) -> None:
|
||||
message = f"Vector store '{vector_store_name}' not found. Use client.vector_dbs.list() to list available vector stores."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DatasetNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced dataset"""
|
||||
|
||||
def __init__(self, dataset_name: str) -> None:
|
||||
message = f"Dataset '{dataset_name}' not found. Use client.datasets.list() to list available datasets."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ToolGroupNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced tool group"""
|
||||
|
||||
def __init__(self, toolgroup_name: str) -> None:
|
||||
message = (
|
||||
f"Tool group '{toolgroup_name}' not found. Use client.toolgroups.list() to list available tool groups."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced session or access is denied"""
|
||||
|
||||
def __init__(self, session_name: str) -> None:
|
||||
message = f"Session '{session_name}' not found or access denied."
|
||||
super().__init__(message)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,14 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class JobStatus(Enum):
|
||||
"""Status of a job execution.
|
||||
:cvar completed: Job has finished successfully
|
||||
:cvar in_progress: Job is currently running
|
||||
:cvar failed: Job has failed during execution
|
||||
:cvar scheduled: Job is scheduled but not yet started
|
||||
:cvar cancelled: Job was cancelled before completion
|
||||
"""
|
||||
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
|
@ -20,5 +28,11 @@ class JobStatus(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
"""A job execution instance with status tracking.
|
||||
|
||||
:param job_id: Unique identifier for the job
|
||||
:param status: Current execution status of the job
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
status: JobStatus
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class Order(Enum):
|
||||
"""Sort order for paginated responses.
|
||||
:cvar asc: Ascending order
|
||||
:cvar desc: Descending order
|
||||
"""
|
||||
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,14 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class PostTrainingMetric(BaseModel):
|
||||
"""Training metrics captured during post-training jobs.
|
||||
|
||||
:param epoch: Training epoch number
|
||||
:param train_loss: Loss value on the training dataset
|
||||
:param validation_loss: Loss value on the validation dataset
|
||||
:param perplexity: Perplexity metric indicating model confidence
|
||||
"""
|
||||
|
||||
epoch: int
|
||||
train_loss: float
|
||||
validation_loss: float
|
||||
|
|
@ -21,7 +29,15 @@ class PostTrainingMetric(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Checkpoint(BaseModel):
|
||||
"""Checkpoint created during training runs"""
|
||||
"""Checkpoint created during training runs.
|
||||
|
||||
:param identifier: Unique identifier for the checkpoint
|
||||
:param created_at: Timestamp when the checkpoint was created
|
||||
:param epoch: Training epoch when the checkpoint was saved
|
||||
:param post_training_job_id: Identifier of the training job that created this checkpoint
|
||||
:param path: File system path where the checkpoint is stored
|
||||
:param training_metrics: (Optional) Training metrics associated with this checkpoint
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
created_at: datetime
|
||||
|
|
|
|||
|
|
@ -13,59 +13,114 @@ from llama_stack.schema_utils import json_schema_type, register_schema
|
|||
|
||||
@json_schema_type
|
||||
class StringType(BaseModel):
|
||||
"""Parameter type for string values.
|
||||
|
||||
:param type: Discriminator type. Always "string"
|
||||
"""
|
||||
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NumberType(BaseModel):
|
||||
"""Parameter type for numeric values.
|
||||
|
||||
:param type: Discriminator type. Always "number"
|
||||
"""
|
||||
|
||||
type: Literal["number"] = "number"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BooleanType(BaseModel):
|
||||
"""Parameter type for boolean values.
|
||||
|
||||
:param type: Discriminator type. Always "boolean"
|
||||
"""
|
||||
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ArrayType(BaseModel):
|
||||
"""Parameter type for array values.
|
||||
|
||||
:param type: Discriminator type. Always "array"
|
||||
"""
|
||||
|
||||
type: Literal["array"] = "array"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ObjectType(BaseModel):
|
||||
"""Parameter type for object values.
|
||||
|
||||
:param type: Discriminator type. Always "object"
|
||||
"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JsonType(BaseModel):
|
||||
"""Parameter type for JSON values.
|
||||
|
||||
:param type: Discriminator type. Always "json"
|
||||
"""
|
||||
|
||||
type: Literal["json"] = "json"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnionType(BaseModel):
|
||||
"""Parameter type for union values.
|
||||
|
||||
:param type: Discriminator type. Always "union"
|
||||
"""
|
||||
|
||||
type: Literal["union"] = "union"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionInputType(BaseModel):
|
||||
"""Parameter type for chat completion input.
|
||||
|
||||
:param type: Discriminator type. Always "chat_completion_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages
|
||||
type: Literal["chat_completion_input"] = "chat_completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionInputType(BaseModel):
|
||||
"""Parameter type for completion input.
|
||||
|
||||
:param type: Discriminator type. Always "completion_input"
|
||||
"""
|
||||
|
||||
# expects InterleavedTextMedia for content
|
||||
type: Literal["completion_input"] = "completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnInputType(BaseModel):
|
||||
"""Parameter type for agent turn input.
|
||||
|
||||
:param type: Discriminator type. Always "agent_turn_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages (may also include attachments?)
|
||||
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogType(BaseModel):
|
||||
"""Parameter type for dialog data with semantic output labels.
|
||||
|
||||
:param type: Discriminator type. Always "dialog"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages
|
||||
# this type semantically contains the output label whereas ChatCompletionInputType does not
|
||||
type: Literal["dialog"] = "dialog"
|
||||
|
|
|
|||
|
|
@ -94,6 +94,10 @@ register_schema(DataSource, name="DataSource")
|
|||
class CommonDatasetFields(BaseModel):
|
||||
"""
|
||||
Common fields for a dataset.
|
||||
|
||||
:param purpose: Purpose of the dataset indicating its intended use
|
||||
:param source: Data source configuration for the dataset
|
||||
:param metadata: Additional metadata for the dataset
|
||||
"""
|
||||
|
||||
purpose: DatasetPurpose
|
||||
|
|
@ -106,6 +110,11 @@ class CommonDatasetFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
"""Dataset resource for storing and accessing training or evaluation data.
|
||||
|
||||
:param type: Type of resource, always 'dataset' for datasets
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.dataset] = ResourceType.dataset
|
||||
|
||||
@property
|
||||
|
|
@ -118,10 +127,20 @@ class Dataset(CommonDatasetFields, Resource):
|
|||
|
||||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
"""Input parameters for dataset operations.
|
||||
|
||||
:param dataset_id: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
"""Response from listing datasets.
|
||||
|
||||
:param data: List of datasets
|
||||
"""
|
||||
|
||||
data: list[Dataset]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -81,6 +81,29 @@ class DynamicApiMeta(EnumMeta):
|
|||
|
||||
@json_schema_type
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
"""Enumeration of all available APIs in the Llama Stack system.
|
||||
:cvar providers: Provider management and configuration
|
||||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
:cvar eval: Model evaluation and benchmarking framework
|
||||
:cvar post_training: Fine-tuning and model training
|
||||
:cvar tool_runtime: Tool execution and management
|
||||
:cvar telemetry: Observability and system monitoring
|
||||
:cvar models: Model metadata and management
|
||||
:cvar shields: Safety shield implementations
|
||||
:cvar vector_dbs: Vector database management
|
||||
:cvar datasets: Dataset creation and management
|
||||
:cvar scoring_functions: Scoring function definitions
|
||||
:cvar benchmarks: Benchmark suite management
|
||||
:cvar tool_groups: Tool group organization
|
||||
:cvar files: File storage and management
|
||||
:cvar inspect: Built-in system inspection and introspection
|
||||
"""
|
||||
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
|
|
|
|||
|
|
@ -54,6 +54,9 @@ class ListOpenAIFileResponse(BaseModel):
|
|||
Response for listing files in OpenAI Files API.
|
||||
|
||||
:param data: List of file objects
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
:param first_id: ID of the first file in the list for pagination
|
||||
:param last_id: ID of the last file in the list for pagination
|
||||
:param object: The object type, which is always "list"
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -41,11 +41,23 @@ from enum import StrEnum
|
|||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
"""Greedy sampling strategy that selects the highest probability token at each step.
|
||||
|
||||
:param type: Must be "greedy" to identify this sampling strategy
|
||||
"""
|
||||
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
"""Top-p (nucleus) sampling strategy that samples from the smallest set of tokens with cumulative probability >= p.
|
||||
|
||||
:param type: Must be "top_p" to identify this sampling strategy
|
||||
:param temperature: Controls randomness in sampling. Higher values increase randomness
|
||||
:param top_p: Cumulative probability threshold for nucleus sampling. Defaults to 0.95
|
||||
"""
|
||||
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: float | None = Field(..., gt=0.0)
|
||||
top_p: float | None = 0.95
|
||||
|
|
@ -53,6 +65,12 @@ class TopPSamplingStrategy(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
"""Top-k sampling strategy that restricts sampling to the k most likely tokens.
|
||||
|
||||
:param type: Must be "top_k" to identify this sampling strategy
|
||||
:param top_k: Number of top tokens to consider for sampling. Must be at least 1
|
||||
"""
|
||||
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
|
@ -108,11 +126,21 @@ class QuantizationType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class Fp8QuantizationConfig(BaseModel):
|
||||
"""Configuration for 8-bit floating point quantization.
|
||||
|
||||
:param type: Must be "fp8_mixed" to identify this quantization type
|
||||
"""
|
||||
|
||||
type: Literal["fp8_mixed"] = "fp8_mixed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Bf16QuantizationConfig(BaseModel):
|
||||
"""Configuration for BFloat16 precision (typically no quantization).
|
||||
|
||||
:param type: Must be "bf16" to identify this quantization type
|
||||
"""
|
||||
|
||||
type: Literal["bf16"] = "bf16"
|
||||
|
||||
|
||||
|
|
@ -202,6 +230,14 @@ register_schema(Message, name="Message")
|
|||
|
||||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response from a tool invocation.
|
||||
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was invoked
|
||||
:param content: The response content from the tool
|
||||
:param metadata: (Optional) Additional metadata about the tool response
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
content: InterleavedContent
|
||||
|
|
@ -439,24 +475,55 @@ class EmbeddingsResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
"""Text content part for OpenAI-compatible chat completion messages.
|
||||
|
||||
:param type: Must be "text" to identify this as text content
|
||||
:param text: The text content of the message
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
"""Image URL specification for OpenAI-compatible chat completion messages.
|
||||
|
||||
:param url: URL of the image to include in the message
|
||||
:param detail: (Optional) Level of detail for image processing. Can be "low", "high", or "auto"
|
||||
"""
|
||||
|
||||
url: str
|
||||
detail: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
||||
"""Image content part for OpenAI-compatible chat completion messages.
|
||||
|
||||
:param type: Must be "image_url" to identify this as image content
|
||||
:param image_url: Image URL specification and processing details
|
||||
"""
|
||||
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenAIImageURL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileFile(BaseModel):
|
||||
file_data: str | None = None
|
||||
file_id: str | None = None
|
||||
filename: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFile(BaseModel):
|
||||
type: Literal["file"] = "file"
|
||||
file: OpenAIFileFile
|
||||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
|
@ -464,6 +531,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
|
|||
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
|
|
@ -489,18 +558,32 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
"""Function call details for OpenAI-compatible tool calls.
|
||||
|
||||
:param name: (Optional) Name of the function to call
|
||||
:param arguments: (Optional) Arguments to pass to the function as a JSON string
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
"""Tool call specification for OpenAI-compatible chat completion responses.
|
||||
|
||||
:param index: (Optional) Index of the tool call in the list
|
||||
:param id: (Optional) Unique identifier for the tool call
|
||||
:param type: Must be "function" to identify this as a function call
|
||||
:param function: (Optional) Function call details
|
||||
"""
|
||||
|
||||
index: int | None = None
|
||||
id: str | None = None
|
||||
type: Literal["function"] = "function"
|
||||
|
|
@ -518,7 +601,7 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
|
@ -534,7 +617,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -547,7 +630,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
@ -564,11 +647,24 @@ register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatText(BaseModel):
|
||||
"""Text response format for OpenAI-compatible chat completion requests.
|
||||
|
||||
:param type: Must be "text" to indicate plain text response format
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
"""JSON schema specification for OpenAI-compatible structured response format.
|
||||
|
||||
:param name: Name of the schema
|
||||
:param description: (Optional) Description of the schema
|
||||
:param strict: (Optional) Whether to enforce strict adherence to the schema
|
||||
:param schema: (Optional) The JSON schema definition
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str | None
|
||||
strict: bool | None
|
||||
|
|
@ -582,12 +678,23 @@ class OpenAIJSONSchema(TypedDict, total=False):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONSchema(BaseModel):
|
||||
"""JSON schema response format for OpenAI-compatible chat completion requests.
|
||||
|
||||
:param type: Must be "json_schema" to indicate structured JSON response format
|
||||
:param json_schema: The JSON schema specification for the response
|
||||
"""
|
||||
|
||||
type: Literal["json_schema"] = "json_schema"
|
||||
json_schema: OpenAIJSONSchema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONObject(BaseModel):
|
||||
"""JSON object response format for OpenAI-compatible chat completion requests.
|
||||
|
||||
:param type: Must be "json_object" to indicate generic JSON object response format
|
||||
"""
|
||||
|
||||
type: Literal["json_object"] = "json_object"
|
||||
|
||||
|
||||
|
|
@ -846,11 +953,21 @@ class EmbeddingTaskType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
"""Response from a batch completion request.
|
||||
|
||||
:param batch: List of completion responses, one for each input in the batch
|
||||
"""
|
||||
|
||||
batch: list[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
"""Response from a batch chat completion request.
|
||||
|
||||
:param batch: List of chat completion responses, one for each conversation in the batch
|
||||
"""
|
||||
|
||||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
|
|
@ -860,6 +977,15 @@ class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
|||
|
||||
@json_schema_type
|
||||
class ListOpenAIChatCompletionResponse(BaseModel):
|
||||
"""Response from listing OpenAI-compatible chat completions.
|
||||
|
||||
:param data: List of chat completion objects with their input messages
|
||||
:param has_more: Whether there are more completions available beyond this list
|
||||
:param first_id: ID of the first completion in this list
|
||||
:param last_id: ID of the last completion in this list
|
||||
:param object: Must be "list" to identify this as a list response
|
||||
"""
|
||||
|
||||
data: list[OpenAICompletionWithInputMessages]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
|
|
|
|||
|
|
@ -14,6 +14,13 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
"""Information about an API route including its path, method, and implementing providers.
|
||||
|
||||
:param route: The API endpoint path
|
||||
:param method: HTTP method for the route
|
||||
:param provider_types: List of provider types that implement this route
|
||||
"""
|
||||
|
||||
route: str
|
||||
method: str
|
||||
provider_types: list[str]
|
||||
|
|
@ -21,15 +28,30 @@ class RouteInfo(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
"""Health status information for the service.
|
||||
|
||||
:param status: Current health status of the service
|
||||
"""
|
||||
|
||||
status: HealthStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
"""Version information for the service.
|
||||
|
||||
:param version: Version number of the service
|
||||
"""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
"""Response containing a list of all available API routes.
|
||||
|
||||
:param data: List of available route information objects
|
||||
"""
|
||||
|
||||
data: list[RouteInfo]
|
||||
|
||||
|
||||
|
|
@ -37,17 +59,17 @@ class ListRoutesResponse(BaseModel):
|
|||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all routes.
|
||||
"""List all available API routes with their methods and implementing providers.
|
||||
|
||||
:returns: A ListRoutesResponse.
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the health of the service.
|
||||
"""Get the current health status of the service.
|
||||
|
||||
:returns: A HealthInfo.
|
||||
:returns: Health information indicating if the service is operational.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -55,6 +77,6 @@ class Inspect(Protocol):
|
|||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
|
||||
:returns: A VersionInfo.
|
||||
:returns: Version information containing the service version number.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -23,12 +23,27 @@ class CommonModelFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ModelType(StrEnum):
|
||||
"""Enumeration of supported model types in Llama Stack.
|
||||
:cvar llm: Large language model for text generation and completion
|
||||
:cvar embedding: Embedding model for converting text to vector representations
|
||||
"""
|
||||
|
||||
llm = "llm"
|
||||
embedding = "embedding"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Model(CommonModelFields, Resource):
|
||||
"""A model resource representing an AI model registered in Llama Stack.
|
||||
|
||||
:param type: The resource type, always 'model' for model resources
|
||||
:param model_type: The type of model (LLM or embedding model)
|
||||
:param metadata: Any additional metadata for this model
|
||||
:param identifier: Unique identifier for this resource in llama stack
|
||||
:param provider_resource_id: Unique identifier for this resource in the provider
|
||||
:param provider_id: ID of the provider that owns this resource
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.model] = ResourceType.model
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -18,6 +18,12 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class OptimizerType(Enum):
|
||||
"""Available optimizer algorithms for training.
|
||||
:cvar adam: Adaptive Moment Estimation optimizer
|
||||
:cvar adamw: AdamW optimizer with weight decay
|
||||
:cvar sgd: Stochastic Gradient Descent optimizer
|
||||
"""
|
||||
|
||||
adam = "adam"
|
||||
adamw = "adamw"
|
||||
sgd = "sgd"
|
||||
|
|
@ -25,12 +31,28 @@ class OptimizerType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class DatasetFormat(Enum):
|
||||
"""Format of the training dataset.
|
||||
:cvar instruct: Instruction-following format with prompt and completion
|
||||
:cvar dialog: Multi-turn conversation format with messages
|
||||
"""
|
||||
|
||||
instruct = "instruct"
|
||||
dialog = "dialog"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DataConfig(BaseModel):
|
||||
"""Configuration for training data and data loading.
|
||||
|
||||
:param dataset_id: Unique identifier for the training dataset
|
||||
:param batch_size: Number of samples per training batch
|
||||
:param shuffle: Whether to shuffle the dataset during training
|
||||
:param data_format: Format of the dataset (instruct or dialog)
|
||||
:param validation_dataset_id: (Optional) Unique identifier for the validation dataset
|
||||
:param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
|
||||
:param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
|
|
@ -42,6 +64,14 @@ class DataConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OptimizerConfig(BaseModel):
|
||||
"""Configuration parameters for the optimization algorithm.
|
||||
|
||||
:param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
|
||||
:param lr: Learning rate for the optimizer
|
||||
:param weight_decay: Weight decay coefficient for regularization
|
||||
:param num_warmup_steps: Number of steps for learning rate warmup
|
||||
"""
|
||||
|
||||
optimizer_type: OptimizerType
|
||||
lr: float
|
||||
weight_decay: float
|
||||
|
|
@ -50,6 +80,14 @@ class OptimizerConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class EfficiencyConfig(BaseModel):
|
||||
"""Configuration for memory and compute efficiency optimizations.
|
||||
|
||||
:param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
|
||||
:param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
|
||||
:param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
|
||||
:param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
|
||||
"""
|
||||
|
||||
enable_activation_checkpointing: bool | None = False
|
||||
enable_activation_offloading: bool | None = False
|
||||
memory_efficient_fsdp_wrap: bool | None = False
|
||||
|
|
@ -58,6 +96,18 @@ class EfficiencyConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Comprehensive configuration for the training process.
|
||||
|
||||
:param n_epochs: Number of training epochs to run
|
||||
:param max_steps_per_epoch: Maximum number of steps to run per epoch
|
||||
:param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
|
||||
:param max_validation_steps: (Optional) Maximum number of validation steps per epoch
|
||||
:param data_config: (Optional) Configuration for data loading and formatting
|
||||
:param optimizer_config: (Optional) Configuration for the optimization algorithm
|
||||
:param efficiency_config: (Optional) Configuration for memory and compute optimizations
|
||||
:param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
|
||||
"""
|
||||
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int = 1
|
||||
gradient_accumulation_steps: int = 1
|
||||
|
|
@ -70,6 +120,18 @@ class TrainingConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class LoraFinetuningConfig(BaseModel):
|
||||
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning.
|
||||
|
||||
:param type: Algorithm type identifier, always "LoRA"
|
||||
:param lora_attn_modules: List of attention module names to apply LoRA to
|
||||
:param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
|
||||
:param apply_lora_to_output: Whether to apply LoRA to output projection layers
|
||||
:param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
|
||||
:param alpha: LoRA scaling parameter that controls adaptation strength
|
||||
:param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
|
||||
:param quantize_base: (Optional) Whether to quantize the base model weights
|
||||
"""
|
||||
|
||||
type: Literal["LoRA"] = "LoRA"
|
||||
lora_attn_modules: list[str]
|
||||
apply_lora_to_mlp: bool
|
||||
|
|
@ -82,6 +144,13 @@ class LoraFinetuningConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
"""Configuration for Quantization-Aware Training (QAT) fine-tuning.
|
||||
|
||||
:param type: Algorithm type identifier, always "QAT"
|
||||
:param quantizer_name: Name of the quantization algorithm to use
|
||||
:param group_size: Size of groups for grouped quantization
|
||||
"""
|
||||
|
||||
type: Literal["QAT"] = "QAT"
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
|
@ -93,7 +162,11 @@ register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
|||
|
||||
@json_schema_type
|
||||
class PostTrainingJobLogStream(BaseModel):
|
||||
"""Stream of logs from a finetuning job."""
|
||||
"""Stream of logs from a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param log_lines: List of log message strings from the training process
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
log_lines: list[str]
|
||||
|
|
@ -101,6 +174,10 @@ class PostTrainingJobLogStream(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RLHFAlgorithm(Enum):
|
||||
"""Available reinforcement learning from human feedback algorithms.
|
||||
:cvar dpo: Direct Preference Optimization algorithm
|
||||
"""
|
||||
|
||||
dpo = "dpo"
|
||||
|
||||
|
||||
|
|
@ -114,13 +191,31 @@ class DPOLossType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class DPOAlignmentConfig(BaseModel):
|
||||
"""Configuration for Direct Preference Optimization (DPO) alignment.
|
||||
|
||||
:param beta: Temperature parameter for the DPO loss
|
||||
:param loss_type: The type of loss function to use for DPO
|
||||
"""
|
||||
|
||||
beta: float
|
||||
loss_type: DPOLossType = DPOLossType.sigmoid
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingRLHFRequest(BaseModel):
|
||||
"""Request to finetune a model."""
|
||||
"""Request to finetune a model using reinforcement learning from human feedback.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param finetuned_model: URL or path to the base model to fine-tune
|
||||
:param dataset_id: Unique identifier for the training dataset
|
||||
:param validation_dataset_id: Unique identifier for the validation dataset
|
||||
:param algorithm: RLHF algorithm to use for training
|
||||
:param algorithm_config: Configuration parameters for the RLHF algorithm
|
||||
:param optimizer_config: Configuration parameters for the optimization algorithm
|
||||
:param training_config: Configuration parameters for the training process
|
||||
:param hyperparam_search_config: Configuration for hyperparameter search
|
||||
:param logger_config: Configuration for training logging
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
|
|
@ -146,7 +241,16 @@ class PostTrainingJob(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class PostTrainingJobStatusResponse(BaseModel):
|
||||
"""Status of a finetuning job."""
|
||||
"""Status of a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param status: Current status of the training job
|
||||
:param scheduled_at: (Optional) Timestamp when the job was scheduled
|
||||
:param started_at: (Optional) Timestamp when the job execution began
|
||||
:param completed_at: (Optional) Timestamp when the job finished, if completed
|
||||
:param resources_allocated: (Optional) Information about computational resources allocated to the job
|
||||
:param checkpoints: List of model checkpoints created during training
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
status: JobStatus
|
||||
|
|
@ -166,7 +270,11 @@ class ListPostTrainingJobsResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a finetuning job."""
|
||||
"""Artifacts of a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param checkpoints: List of model checkpoints created during training
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,15 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about a registered provider including its configuration and health status.
|
||||
|
||||
:param api: The API name this provider implements
|
||||
:param provider_id: Unique identifier for the provider
|
||||
:param provider_type: The type of provider implementation
|
||||
:param config: Configuration parameters for the provider
|
||||
:param health: Current health status of the provider
|
||||
"""
|
||||
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
|
@ -22,6 +31,11 @@ class ProviderInfo(BaseModel):
|
|||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
"""Response containing a list of all available providers.
|
||||
|
||||
:param data: List of provider information objects
|
||||
"""
|
||||
|
||||
data: list[ProviderInfo]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,13 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
"""Severity level of a safety violation.
|
||||
|
||||
:cvar INFO: Informational level violation that does not require action
|
||||
:cvar WARN: Warning level violation that suggests caution but allows continuation
|
||||
:cvar ERROR: Error level violation that requires blocking or intervention
|
||||
"""
|
||||
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
|
|
@ -24,6 +31,13 @@ class ViolationLevel(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class SafetyViolation(BaseModel):
|
||||
"""Details of a safety violation detected by content moderation.
|
||||
|
||||
:param violation_level: Severity level of the violation
|
||||
:param user_message: (Optional) Message to convey to the user about the violation
|
||||
:param metadata: Additional metadata including specific violation codes for debugging and telemetry
|
||||
"""
|
||||
|
||||
violation_level: ViolationLevel
|
||||
|
||||
# what message should you convey to the user
|
||||
|
|
@ -36,6 +50,11 @@ class SafetyViolation(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RunShieldResponse(BaseModel):
|
||||
"""Response from running a safety shield.
|
||||
|
||||
:param violation: (Optional) Safety violation detected by the shield, if any
|
||||
"""
|
||||
|
||||
violation: SafetyViolation | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,12 @@ class ScoringResult(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ScoreBatchResponse(BaseModel):
|
||||
"""Response from batch scoring operations on datasets.
|
||||
|
||||
:param dataset_id: (Optional) The identifier of the dataset that was scored
|
||||
:param results: A map of scoring function name to ScoringResult
|
||||
"""
|
||||
|
||||
dataset_id: str | None = None
|
||||
results: dict[str, ScoringResult]
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,12 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(StrEnum):
|
||||
"""Types of scoring function parameter configurations.
|
||||
:cvar llm_as_judge: Use an LLM model to evaluate and score responses
|
||||
:cvar regex_parser: Use regex patterns to extract and score specific parts of responses
|
||||
:cvar basic: Basic scoring with simple aggregation functions
|
||||
"""
|
||||
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
|
@ -32,6 +38,14 @@ class ScoringFnParamsType(StrEnum):
|
|||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(StrEnum):
|
||||
"""Types of aggregation functions for scoring results.
|
||||
:cvar average: Calculate the arithmetic mean of scores
|
||||
:cvar weighted_average: Calculate a weighted average of scores
|
||||
:cvar median: Calculate the median value of scores
|
||||
:cvar categorical_count: Count occurrences of categorical values
|
||||
:cvar accuracy: Calculate accuracy as the proportion of correct answers
|
||||
"""
|
||||
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
|
|
@ -41,6 +55,14 @@ class AggregationFunctionType(StrEnum):
|
|||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
"""Parameters for LLM-as-judge scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always llm_as_judge
|
||||
:param judge_model: Identifier of the LLM model to use as a judge for scoring
|
||||
:param prompt_template: (Optional) Custom prompt template for the judge model
|
||||
:param judge_score_regexes: Regexes to extract the answer from generated response
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
|
||||
judge_model: str
|
||||
prompt_template: str | None = None
|
||||
|
|
@ -56,6 +78,12 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
"""Parameters for regex parser scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always regex_parser
|
||||
:param parsing_regexes: Regex to extract the answer from generated response
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
|
||||
parsing_regexes: list[str] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
|
|
@ -69,6 +97,11 @@ class RegexParserScoringFnParams(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
"""Parameters for basic scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always basic
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
|
|
@ -100,6 +133,10 @@ class CommonScoringFnFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
"""A scoring function resource for evaluating model outputs.
|
||||
:param type: The resource type, always scoring_function
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ class CommonShieldFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Shield(CommonShieldFields, Resource):
|
||||
"""A safety shield resource that can be used to check content"""
|
||||
"""A safety shield resource that can be used to check content.
|
||||
|
||||
:param params: (Optional) Configuration parameters for the shield
|
||||
:param type: The resource type, always shield
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.shield] = ResourceType.shield
|
||||
|
||||
|
|
@ -79,3 +83,11 @@ class Shields(Protocol):
|
|||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
:param identifier: The identifier of the shield to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -14,7 +14,15 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function."""
|
||||
"""The type of filtering function.
|
||||
|
||||
:cvar none: No filtering applied, accept all generated synthetic data
|
||||
:cvar random: Random sampling of generated data points
|
||||
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
|
||||
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
|
||||
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
|
||||
:cvar sigmoid: Apply sigmoid function for probability-based filtering
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
random = "random"
|
||||
|
|
@ -26,7 +34,12 @@ class FilteringFunction(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationRequest(BaseModel):
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
"""
|
||||
|
||||
dialogs: list[Message]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
|
|
@ -35,7 +48,11 @@ class SyntheticDataGenerationRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
|
||||
|
||||
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
|
||||
:param statistics: (Optional) Statistical information about the generation process and filtering results
|
||||
"""
|
||||
|
||||
synthetic_data: list[dict[str, Any]]
|
||||
statistics: dict[str, Any] | None = None
|
||||
|
|
@ -48,4 +65,12 @@ class SyntheticDataGeneration(Protocol):
|
|||
dialogs: list[Message],
|
||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||
model: str | None = None,
|
||||
) -> SyntheticDataGenerationResponse: ...
|
||||
) -> SyntheticDataGenerationResponse:
|
||||
"""Generate synthetic data based on input dialogs and apply filtering.
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
:returns: Response containing filtered synthetic data samples and optional statistics
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -27,12 +27,27 @@ REQUIRED_SCOPE = "telemetry.read"
|
|||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
"""The status of a span indicating whether it completed successfully or with an error.
|
||||
:cvar OK: Span completed successfully without errors
|
||||
:cvar ERROR: Span completed with an error or failure
|
||||
"""
|
||||
|
||||
OK = "ok"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Span(BaseModel):
|
||||
"""A span representing a single operation within a trace.
|
||||
:param span_id: Unique identifier for the span
|
||||
:param trace_id: Unique identifier for the trace this span belongs to
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param start_time: Timestamp when the operation began
|
||||
:param end_time: (Optional) Timestamp when the operation finished, if completed
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
|
||||
"""
|
||||
|
||||
span_id: str
|
||||
trace_id: str
|
||||
parent_span_id: str | None = None
|
||||
|
|
@ -49,6 +64,13 @@ class Span(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
"""A trace representing the complete execution path of a request across multiple operations.
|
||||
:param trace_id: Unique identifier for the trace
|
||||
:param root_span_id: Unique identifier for the root span that started this trace
|
||||
:param start_time: Timestamp when the trace began
|
||||
:param end_time: (Optional) Timestamp when the trace finished, if completed
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
root_span_id: str
|
||||
start_time: datetime
|
||||
|
|
@ -57,6 +79,12 @@ class Trace(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class EventType(Enum):
|
||||
"""The type of telemetry event being logged.
|
||||
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
|
||||
:cvar STRUCTURED_LOG: A structured log event with typed payload data
|
||||
:cvar METRIC: A metric measurement with value and unit
|
||||
"""
|
||||
|
||||
UNSTRUCTURED_LOG = "unstructured_log"
|
||||
STRUCTURED_LOG = "structured_log"
|
||||
METRIC = "metric"
|
||||
|
|
@ -64,6 +92,15 @@ class EventType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class LogSeverity(Enum):
|
||||
"""The severity level of a log message.
|
||||
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
|
||||
:cvar DEBUG: Debug information useful during development
|
||||
:cvar INFO: General informational messages about normal operation
|
||||
:cvar WARN: Warning messages about potentially problematic situations
|
||||
:cvar ERROR: Error messages indicating failures that don't stop execution
|
||||
:cvar CRITICAL: Critical error messages indicating severe failures
|
||||
"""
|
||||
|
||||
VERBOSE = "verbose"
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
|
|
@ -73,6 +110,13 @@ class LogSeverity(Enum):
|
|||
|
||||
|
||||
class EventCommon(BaseModel):
|
||||
"""Common fields shared by all telemetry events.
|
||||
:param trace_id: Unique identifier for the trace this event belongs to
|
||||
:param span_id: Unique identifier for the span this event belongs to
|
||||
:param timestamp: Timestamp when the event occurred
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
|
|
@ -81,6 +125,12 @@ class EventCommon(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
"""An unstructured log event containing a simple text message.
|
||||
:param type: Event type identifier set to UNSTRUCTURED_LOG
|
||||
:param message: The log message text
|
||||
:param severity: The severity level of the log message
|
||||
"""
|
||||
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
|
@ -88,6 +138,13 @@ class UnstructuredLogEvent(EventCommon):
|
|||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
"""A metric event containing a measured value.
|
||||
:param type: Event type identifier set to METRIC
|
||||
:param metric: The name of the metric being measured
|
||||
:param value: The numeric value of the metric measurement
|
||||
:param unit: The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
|
|
@ -96,6 +153,12 @@ class MetricEvent(EventCommon):
|
|||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
"""A metric value included in API responses.
|
||||
:param metric: The name of the metric
|
||||
:param value: The numeric value of the metric
|
||||
:param unit: (Optional) The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
metric: str
|
||||
value: int | float
|
||||
unit: str | None = None
|
||||
|
|
@ -122,17 +185,32 @@ class MetricInResponse(BaseModel):
|
|||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
"""Mixin class for API responses that can include metrics.
|
||||
:param metrics: (Optional) List of metrics associated with the API response
|
||||
"""
|
||||
|
||||
metrics: list[MetricInResponse] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
"""The type of structured log event payload.
|
||||
:cvar SPAN_START: Event indicating the start of a new span
|
||||
:cvar SPAN_END: Event indicating the completion of a span
|
||||
"""
|
||||
|
||||
SPAN_START = "span_start"
|
||||
SPAN_END = "span_end"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
"""Payload for a span start event.
|
||||
:param type: Payload type identifier set to SPAN_START
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
|
||||
name: str
|
||||
parent_span_id: str | None = None
|
||||
|
|
@ -140,6 +218,11 @@ class SpanStartPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class SpanEndPayload(BaseModel):
|
||||
"""Payload for a span end event.
|
||||
:param type: Payload type identifier set to SPAN_END
|
||||
:param status: The final status of the span indicating success or failure
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
|
||||
status: SpanStatus
|
||||
|
||||
|
|
@ -153,6 +236,11 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
|||
|
||||
@json_schema_type
|
||||
class StructuredLogEvent(EventCommon):
|
||||
"""A structured log event containing typed payload data.
|
||||
:param type: Event type identifier set to STRUCTURED_LOG
|
||||
:param payload: The structured payload data for the log event
|
||||
"""
|
||||
|
||||
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
|
@ -166,6 +254,14 @@ register_schema(Event, name="Event")
|
|||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
"""A trace record for evaluation purposes.
|
||||
:param session_id: Unique identifier for the evaluation session
|
||||
:param step: The evaluation step or phase identifier
|
||||
:param input: The input data for the evaluation
|
||||
:param output: The actual output produced during evaluation
|
||||
:param expected_output: The expected output for comparison during evaluation
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
|
|
@ -175,11 +271,22 @@ class EvalTrace(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class SpanWithStatus(Span):
|
||||
"""A span that includes status information.
|
||||
:param status: (Optional) The current status of the span
|
||||
"""
|
||||
|
||||
status: SpanStatus | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryConditionOp(Enum):
|
||||
"""Comparison operators for query conditions.
|
||||
:cvar EQ: Equal to comparison
|
||||
:cvar NE: Not equal to comparison
|
||||
:cvar GT: Greater than comparison
|
||||
:cvar LT: Less than comparison
|
||||
"""
|
||||
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GT = "gt"
|
||||
|
|
@ -188,29 +295,59 @@ class QueryConditionOp(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
"""A condition for filtering query results.
|
||||
:param key: The attribute key to filter on
|
||||
:param op: The comparison operator to apply
|
||||
:param value: The value to compare against
|
||||
"""
|
||||
|
||||
key: str
|
||||
op: QueryConditionOp
|
||||
value: Any
|
||||
|
||||
|
||||
class QueryTracesResponse(BaseModel):
|
||||
"""Response containing a list of traces.
|
||||
:param data: List of traces matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Trace]
|
||||
|
||||
|
||||
class QuerySpansResponse(BaseModel):
|
||||
"""Response containing a list of spans.
|
||||
:param data: List of spans matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Span]
|
||||
|
||||
|
||||
class QuerySpanTreeResponse(BaseModel):
|
||||
"""Response containing a tree structure of spans.
|
||||
:param data: Dictionary mapping span IDs to spans with status information
|
||||
"""
|
||||
|
||||
data: dict[str, SpanWithStatus]
|
||||
|
||||
|
||||
class MetricQueryType(Enum):
|
||||
"""The type of metric query to perform.
|
||||
:cvar RANGE: Query metrics over a time range
|
||||
:cvar INSTANT: Query metrics at a specific point in time
|
||||
"""
|
||||
|
||||
RANGE = "range"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
class MetricLabelOperator(Enum):
|
||||
"""Operators for matching metric labels.
|
||||
:cvar EQUALS: Label value must equal the specified value
|
||||
:cvar NOT_EQUALS: Label value must not equal the specified value
|
||||
:cvar REGEX_MATCH: Label value must match the specified regular expression
|
||||
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
|
||||
"""
|
||||
|
||||
EQUALS = "="
|
||||
NOT_EQUALS = "!="
|
||||
REGEX_MATCH = "=~"
|
||||
|
|
@ -218,6 +355,12 @@ class MetricLabelOperator(Enum):
|
|||
|
||||
|
||||
class MetricLabelMatcher(BaseModel):
|
||||
"""A matcher for filtering metrics by label values.
|
||||
:param name: The name of the label to match
|
||||
:param value: The value to match against
|
||||
:param operator: The comparison operator to use for matching
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
|
||||
|
|
@ -225,24 +368,44 @@ class MetricLabelMatcher(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class MetricLabel(BaseModel):
|
||||
"""A label associated with a metric.
|
||||
:param name: The name of the label
|
||||
:param value: The value of the label
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricDataPoint(BaseModel):
|
||||
"""A single data point in a metric time series.
|
||||
:param timestamp: Unix timestamp when the metric value was recorded
|
||||
:param value: The numeric value of the metric at this timestamp
|
||||
"""
|
||||
|
||||
timestamp: int
|
||||
value: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricSeries(BaseModel):
|
||||
"""A time series of metric data points.
|
||||
:param metric: The name of the metric
|
||||
:param labels: List of labels associated with this metric series
|
||||
:param values: List of data points in chronological order
|
||||
"""
|
||||
|
||||
metric: str
|
||||
labels: list[MetricLabel]
|
||||
values: list[MetricDataPoint]
|
||||
|
||||
|
||||
class QueryMetricsResponse(BaseModel):
|
||||
"""Response containing metric time series data.
|
||||
:param data: List of metric series matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[MetricSeries]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class RRFRanker(BaseModel):
|
|||
|
||||
:param type: The type of ranker, always "rrf"
|
||||
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
|
||||
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
|
||||
Must be greater than 0
|
||||
"""
|
||||
|
||||
type: Literal["rrf"] = "rrf"
|
||||
|
|
@ -76,12 +76,25 @@ class RAGDocument(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
"""Result of a RAG query containing retrieved content and metadata.
|
||||
|
||||
:param content: (Optional) The retrieved content from the query
|
||||
:param metadata: Additional metadata about the query result
|
||||
"""
|
||||
|
||||
content: InterleavedContent | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryGenerator(Enum):
|
||||
"""Types of query generators for RAG systems.
|
||||
|
||||
:cvar default: Default query generator using simple text processing
|
||||
:cvar llm: LLM-based query generator for enhanced query understanding
|
||||
:cvar custom: Custom query generator implementation
|
||||
"""
|
||||
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
|
@ -103,12 +116,25 @@ class RAGSearchMode(StrEnum):
|
|||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the default RAG query generator.
|
||||
|
||||
:param type: Type of query generator, always 'default'
|
||||
:param separator: String separator used to join query terms
|
||||
"""
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
separator: str = " "
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the LLM-based RAG query generator.
|
||||
|
||||
:param type: Type of query generator, always 'llm'
|
||||
:param model: Name of the language model to use for query generation
|
||||
:param template: Template string for formatting the query generation prompt
|
||||
"""
|
||||
|
||||
type: Literal["llm"] = "llm"
|
||||
model: str
|
||||
template: str
|
||||
|
|
@ -166,7 +192,12 @@ class RAGToolRuntime(Protocol):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system"""
|
||||
"""Index documents so they can be used by the RAG system.
|
||||
|
||||
:param documents: List of documents to index in the RAG system
|
||||
:param vector_db_id: ID of the vector database to store the document embeddings
|
||||
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
|
||||
|
|
@ -176,5 +207,11 @@ class RAGToolRuntime(Protocol):
|
|||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
"""Query the RAG system for context; typically invoked by the agent"""
|
||||
"""Query the RAG system for context; typically invoked by the agent.
|
||||
|
||||
:param content: The query content to search for in the indexed documents
|
||||
:param vector_db_ids: List of vector database IDs to search within
|
||||
:param query_config: (Optional) Configuration parameters for the query operation
|
||||
:returns: RAGQueryResult containing the retrieved content and metadata
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -20,6 +20,15 @@ from .rag_tool import RAGToolRuntime
|
|||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter definition for a tool.
|
||||
|
||||
:param name: Name of the parameter
|
||||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
|
|
@ -29,6 +38,15 @@ class ToolParameter(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
"""A tool that can be invoked by agents.
|
||||
|
||||
:param type: Type of resource, always 'tool'
|
||||
:param toolgroup_id: ID of the tool group this tool belongs to
|
||||
:param description: Human-readable description of what the tool does
|
||||
:param parameters: List of parameters this tool accepts
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
description: str
|
||||
|
|
@ -38,6 +56,14 @@ class Tool(Resource):
|
|||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Human-readable description of what the tool does
|
||||
:param parameters: (Optional) List of parameters this tool accepts
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: list[ToolParameter] | None = None
|
||||
|
|
@ -46,6 +72,14 @@ class ToolDef(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ToolGroupInput(BaseModel):
|
||||
"""Input data for registering a tool group.
|
||||
|
||||
:param toolgroup_id: Unique identifier for the tool group
|
||||
:param provider_id: ID of the provider that will handle this tool group
|
||||
:param args: (Optional) Additional arguments to pass to the provider
|
||||
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
|
||||
"""
|
||||
|
||||
toolgroup_id: str
|
||||
provider_id: str
|
||||
args: dict[str, Any] | None = None
|
||||
|
|
@ -54,6 +88,13 @@ class ToolGroupInput(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ToolGroup(Resource):
|
||||
"""A group of related tools managed together.
|
||||
|
||||
:param type: Type of resource, always 'tool_group'
|
||||
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
|
||||
:param args: (Optional) Additional arguments for the tool group
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
|
||||
mcp_endpoint: URL | None = None
|
||||
args: dict[str, Any] | None = None
|
||||
|
|
@ -61,6 +102,14 @@ class ToolGroup(Resource):
|
|||
|
||||
@json_schema_type
|
||||
class ToolInvocationResult(BaseModel):
|
||||
"""Result of a tool invocation.
|
||||
|
||||
:param content: (Optional) The output content from the tool execution
|
||||
:param error_message: (Optional) Error message if the tool execution failed
|
||||
:param error_code: (Optional) Numeric error code if the tool execution failed
|
||||
:param metadata: (Optional) Additional metadata about the tool execution
|
||||
"""
|
||||
|
||||
content: InterleavedContent | None = None
|
||||
error_message: str | None = None
|
||||
error_code: int | None = None
|
||||
|
|
@ -73,14 +122,29 @@ class ToolStore(Protocol):
|
|||
|
||||
|
||||
class ListToolGroupsResponse(BaseModel):
|
||||
"""Response containing a list of tool groups.
|
||||
|
||||
:param data: List of tool groups
|
||||
"""
|
||||
|
||||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
"""Response containing a list of tools.
|
||||
|
||||
:param data: List of tools
|
||||
"""
|
||||
|
||||
data: list[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
:param data: List of tool definitions
|
||||
"""
|
||||
|
||||
data: list[ToolDef]
|
||||
|
||||
|
||||
|
|
@ -158,6 +222,11 @@ class ToolGroups(Protocol):
|
|||
|
||||
|
||||
class SpecialToolGroup(Enum):
|
||||
"""Special tool groups with predefined functionality.
|
||||
|
||||
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
|
||||
"""
|
||||
|
||||
rag_tool = "rag_tool"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class VectorDB(Resource):
|
||||
"""Vector database resource for storing and querying vector embeddings.
|
||||
|
||||
:param type: Type of resource, always 'vector_db' for vector databases
|
||||
:param embedding_model: Name of the embedding model to use for vector generation
|
||||
:param embedding_dimension: Dimension of the embedding vectors
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
|
||||
|
||||
embedding_model: str
|
||||
|
|
@ -31,6 +38,14 @@ class VectorDB(Resource):
|
|||
|
||||
|
||||
class VectorDBInput(BaseModel):
|
||||
"""Input parameters for creating or configuring a vector database.
|
||||
|
||||
:param vector_db_id: Unique identifier for the vector database
|
||||
:param embedding_model: Name of the embedding model to use for vector generation
|
||||
:param embedding_dimension: Dimension of the embedding vectors
|
||||
:param provider_vector_db_id: (Optional) Provider-specific identifier for the vector database
|
||||
"""
|
||||
|
||||
vector_db_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
|
|
@ -39,6 +54,11 @@ class VectorDBInput(BaseModel):
|
|||
|
||||
|
||||
class ListVectorDBsResponse(BaseModel):
|
||||
"""Response from listing vector databases.
|
||||
|
||||
:param data: List of vector databases
|
||||
"""
|
||||
|
||||
data: list[VectorDB]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.strong_typing.schema import register_schema
|
||||
|
||||
|
|
@ -94,12 +94,27 @@ class Chunk(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class QueryChunksResponse(BaseModel):
|
||||
"""Response from querying chunks in a vector database.
|
||||
|
||||
:param chunks: List of content chunks returned from the query
|
||||
:param scores: Relevance scores corresponding to each returned chunk
|
||||
"""
|
||||
|
||||
chunks: list[Chunk]
|
||||
scores: list[float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileCounts(BaseModel):
|
||||
"""File processing status counts for a vector store.
|
||||
|
||||
:param completed: Number of files that have been successfully processed
|
||||
:param cancelled: Number of files that had their processing cancelled
|
||||
:param failed: Number of files that failed to process
|
||||
:param in_progress: Number of files currently being processed
|
||||
:param total: Total number of files in the vector store
|
||||
"""
|
||||
|
||||
completed: int
|
||||
cancelled: int
|
||||
failed: int
|
||||
|
|
@ -109,7 +124,20 @@ class VectorStoreFileCounts(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreObject(BaseModel):
|
||||
"""OpenAI Vector Store object."""
|
||||
"""OpenAI Vector Store object.
|
||||
|
||||
:param id: Unique identifier for the vector store
|
||||
:param object: Object type identifier, always "vector_store"
|
||||
:param created_at: Timestamp when the vector store was created
|
||||
:param name: (Optional) Name of the vector store
|
||||
:param usage_bytes: Storage space used by the vector store in bytes
|
||||
:param file_counts: File processing status counts for the vector store
|
||||
:param status: Current status of the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param expires_at: (Optional) Timestamp when the vector store will expire
|
||||
:param last_active_at: (Optional) Timestamp of last activity on the vector store
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store"
|
||||
|
|
@ -126,7 +154,14 @@ class VectorStoreObject(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreCreateRequest(BaseModel):
|
||||
"""Request to create a vector store."""
|
||||
"""Request to create a vector store.
|
||||
|
||||
:param name: (Optional) Name for the vector store
|
||||
:param file_ids: List of file IDs to include in the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
file_ids: list[str] = Field(default_factory=list)
|
||||
|
|
@ -137,7 +172,12 @@ class VectorStoreCreateRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreModifyRequest(BaseModel):
|
||||
"""Request to modify a vector store."""
|
||||
"""Request to modify a vector store.
|
||||
|
||||
:param name: (Optional) Updated name for the vector store
|
||||
:param expires_after: (Optional) Updated expiration policy for the vector store
|
||||
:param metadata: (Optional) Updated set of key-value pairs for the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
expires_after: dict[str, Any] | None = None
|
||||
|
|
@ -146,7 +186,14 @@ class VectorStoreModifyRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreListResponse(BaseModel):
|
||||
"""Response from listing vector stores."""
|
||||
"""Response from listing vector stores.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store objects
|
||||
:param first_id: (Optional) ID of the first vector store in the list for pagination
|
||||
:param last_id: (Optional) ID of the last vector store in the list for pagination
|
||||
:param has_more: Whether there are more vector stores available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreObject]
|
||||
|
|
@ -157,7 +204,14 @@ class VectorStoreListResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchRequest(BaseModel):
|
||||
"""Request to search a vector store."""
|
||||
"""Request to search a vector store.
|
||||
|
||||
:param query: Search query as a string or list of strings
|
||||
:param filters: (Optional) Filters based on file attributes to narrow search results
|
||||
:param max_num_results: Maximum number of results to return, defaults to 10
|
||||
:param ranking_options: (Optional) Options for ranking and filtering search results
|
||||
:param rewrite_query: Whether to rewrite the query for better vector search performance
|
||||
"""
|
||||
|
||||
query: str | list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
|
|
@ -168,13 +222,26 @@ class VectorStoreSearchRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreContent(BaseModel):
|
||||
"""Content item from a vector store file or search result.
|
||||
|
||||
:param type: Content type, currently only "text" is supported
|
||||
:param text: The actual text content
|
||||
"""
|
||||
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchResponse(BaseModel):
|
||||
"""Response from searching a vector store."""
|
||||
"""Response from searching a vector store.
|
||||
|
||||
:param file_id: Unique identifier of the file containing the result
|
||||
:param filename: Name of the file containing the result
|
||||
:param score: Relevance score for this search result
|
||||
:param attributes: (Optional) Key-value attributes associated with the file
|
||||
:param content: List of content items matching the search query
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
|
|
@ -185,7 +252,14 @@ class VectorStoreSearchResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchResponsePage(BaseModel):
|
||||
"""Response from searching a vector store."""
|
||||
"""Paginated response from searching a vector store.
|
||||
|
||||
:param object: Object type identifier for the search results page
|
||||
:param search_query: The original search query that was executed
|
||||
:param data: List of search result objects
|
||||
:param has_more: Whether there are more results available beyond this page
|
||||
:param next_page: (Optional) Token for retrieving the next page of results
|
||||
"""
|
||||
|
||||
object: str = "vector_store.search_results.page"
|
||||
search_query: str
|
||||
|
|
@ -196,7 +270,12 @@ class VectorStoreSearchResponsePage(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreDeleteResponse(BaseModel):
|
||||
"""Response from deleting a vector store."""
|
||||
"""Response from deleting a vector store.
|
||||
|
||||
:param id: Unique identifier of the deleted vector store
|
||||
:param object: Object type identifier for the deletion response
|
||||
:param deleted: Whether the deletion operation was successful
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.deleted"
|
||||
|
|
@ -205,17 +284,34 @@ class VectorStoreDeleteResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyAuto(BaseModel):
|
||||
"""Automatic chunking strategy for vector store files.
|
||||
|
||||
:param type: Strategy type, always "auto" for automatic chunking
|
||||
"""
|
||||
|
||||
type: Literal["auto"] = "auto"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyStaticConfig(BaseModel):
|
||||
"""Configuration for static chunking strategy.
|
||||
|
||||
:param chunk_overlap_tokens: Number of tokens to overlap between adjacent chunks
|
||||
:param max_chunk_size_tokens: Maximum number of tokens per chunk, must be between 100 and 4096
|
||||
"""
|
||||
|
||||
chunk_overlap_tokens: int = 400
|
||||
max_chunk_size_tokens: int = Field(800, ge=100, le=4096)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyStatic(BaseModel):
|
||||
"""Static chunking strategy with configurable parameters.
|
||||
|
||||
:param type: Strategy type, always "static" for static chunking
|
||||
:param static: Configuration parameters for the static chunking strategy
|
||||
"""
|
||||
|
||||
type: Literal["static"] = "static"
|
||||
static: VectorStoreChunkingStrategyStaticConfig
|
||||
|
||||
|
|
@ -227,6 +323,12 @@ register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
|
|||
|
||||
|
||||
class SearchRankingOptions(BaseModel):
|
||||
"""Options for ranking and filtering search results.
|
||||
|
||||
:param ranker: (Optional) Name of the ranking algorithm to use
|
||||
:param score_threshold: (Optional) Minimum relevance score threshold for results
|
||||
"""
|
||||
|
||||
ranker: str | None = None
|
||||
# NOTE: OpenAI File Search Tool requires threshold to be between 0 and 1, however
|
||||
# we don't guarantee that the score is between 0 and 1, so will leave this unconstrained
|
||||
|
|
@ -236,6 +338,12 @@ class SearchRankingOptions(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreFileLastError(BaseModel):
|
||||
"""Error information for failed vector store file processing.
|
||||
|
||||
:param code: Error code indicating the type of failure
|
||||
:param message: Human-readable error message describing the failure
|
||||
"""
|
||||
|
||||
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
|
||||
message: str
|
||||
|
||||
|
|
@ -246,7 +354,18 @@ register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus")
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreFileObject(BaseModel):
|
||||
"""OpenAI Vector Store File object."""
|
||||
"""OpenAI Vector Store File object.
|
||||
|
||||
:param id: Unique identifier for the file
|
||||
:param object: Object type identifier, always "vector_store.file"
|
||||
:param attributes: Key-value attributes associated with the file
|
||||
:param chunking_strategy: Strategy used for splitting the file into chunks
|
||||
:param created_at: Timestamp when the file was added to the vector store
|
||||
:param last_error: (Optional) Error information if file processing failed
|
||||
:param status: Current processing status of the file
|
||||
:param usage_bytes: Storage space used by this file in bytes
|
||||
:param vector_store_id: ID of the vector store containing this file
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file"
|
||||
|
|
@ -261,7 +380,14 @@ class VectorStoreFileObject(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreListFilesResponse(BaseModel):
|
||||
"""Response from listing vector stores."""
|
||||
"""Response from listing files in a vector store.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store file objects
|
||||
:param first_id: (Optional) ID of the first file in the list for pagination
|
||||
:param last_id: (Optional) ID of the last file in the list for pagination
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreFileObject]
|
||||
|
|
@ -272,7 +398,13 @@ class VectorStoreListFilesResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreFileContentsResponse(BaseModel):
|
||||
"""Response from retrieving the contents of a vector store file."""
|
||||
"""Response from retrieving the contents of a vector store file.
|
||||
|
||||
:param file_id: Unique identifier for the file
|
||||
:param filename: Name of the file
|
||||
:param attributes: Key-value attributes associated with the file
|
||||
:param content: List of content items from the file
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
|
|
@ -282,7 +414,12 @@ class VectorStoreFileContentsResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VectorStoreFileDeleteResponse(BaseModel):
|
||||
"""Response from deleting a vector store file."""
|
||||
"""Response from deleting a vector store file.
|
||||
|
||||
:param id: Unique identifier of the deleted file
|
||||
:param object: Object type identifier for the deletion response
|
||||
:param deleted: Whether the deletion operation was successful
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file.deleted"
|
||||
|
|
@ -478,6 +615,11 @@ class VectorIO(Protocol):
|
|||
"""List files in a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to list files from.
|
||||
:param limit: (Optional) A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: (Optional) Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param after: (Optional) A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: (Optional) A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:param filter: (Optional) Filter by file status to only return files with the specified status.
|
||||
:returns: A VectorStoreListFilesResponse containing the list of files.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -323,7 +323,7 @@ def _hf_download(
|
|||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
|
|
@ -361,7 +361,7 @@ def _meta_download(
|
|||
info: "LlamaDownloadInfo",
|
||||
max_concurrent_downloads: int,
|
||||
):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model.descriptor()))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
|
@ -403,7 +403,7 @@ class Manifest(BaseModel):
|
|||
|
||||
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file) as f:
|
||||
d = json.load(f)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import os
|
|||
import shutil
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,73 +23,80 @@ from termcolor import colored, cprint
|
|||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.build import (
|
||||
from llama_stack.core.build import (
|
||||
SERVER_DEPENDENCIES,
|
||||
build_image,
|
||||
get_provider_dependencies,
|
||||
)
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def available_templates_specs() -> dict[str, BuildConfig]:
|
||||
def available_distros_specs() -> dict[str, BuildConfig]:
|
||||
import yaml
|
||||
|
||||
template_specs = {}
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
template_name = p.parent.name
|
||||
distro_specs = {}
|
||||
for p in DISTRIBS_PATH.rglob("*build.yaml"):
|
||||
distro_name = p.parent.name
|
||||
with open(p) as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs[template_name] = build_config
|
||||
return template_specs
|
||||
distro_specs[distro_name] = build_config
|
||||
return distro_specs
|
||||
|
||||
|
||||
def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||
if args.list_templates:
|
||||
return _run_template_list_cmd()
|
||||
if args.list_distros:
|
||||
return _run_distro_list_cmd()
|
||||
|
||||
if args.image_type == ImageType.VENV.value:
|
||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||
image_name = args.image_name or current_venv
|
||||
elif args.image_type == ImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
image_name = args.image_name or current_conda_env
|
||||
else:
|
||||
image_name = args.image_name
|
||||
|
||||
if args.template:
|
||||
available_templates = available_templates_specs()
|
||||
if args.template not in available_templates:
|
||||
cprint(
|
||||
"The --template argument is deprecated. Please use --distro instead.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
distro_name = args.template
|
||||
else:
|
||||
distro_name = args.distribution
|
||||
|
||||
if distro_name:
|
||||
available_distros = available_distros_specs()
|
||||
if distro_name not in available_distros:
|
||||
cprint(
|
||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
||||
f"Could not find distribution {distro_name}. Please run `llama stack build --list-distros` to check out the available distributions",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
build_config = available_templates[args.template]
|
||||
build_config = available_distros[distro_name]
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
else:
|
||||
cprint(
|
||||
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
||||
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {distro_name}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
@ -132,14 +139,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
if not args.image_type:
|
||||
cprint(
|
||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||
f"Please specify a image-type (container | venv) for {args.template}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
|
||||
elif not args.config and not args.template:
|
||||
elif not args.config and not distro_name:
|
||||
name = prompt(
|
||||
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
|
||||
validator=Validator.from_callable(
|
||||
|
|
@ -158,22 +165,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
),
|
||||
)
|
||||
|
||||
if image_type == ImageType.CONDA.value:
|
||||
if not image_name:
|
||||
cprint(
|
||||
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
image_name = f"llamastack-{name}"
|
||||
else:
|
||||
cprint(
|
||||
f"Using conda environment {image_name}",
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
image_name = f"llamastack-{name}"
|
||||
image_name = f"llamastack-{name}"
|
||||
|
||||
cprint(
|
||||
textwrap.dedent(
|
||||
|
|
@ -236,7 +228,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
sys.exit(1)
|
||||
|
||||
if args.print_deps_only:
|
||||
print(f"# Dependencies for {args.template or args.config or image_name}")
|
||||
print(f"# Dependencies for {distro_name or args.config or image_name}")
|
||||
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
print(f"uv pip install {' '.join(normal_deps)}")
|
||||
|
|
@ -251,7 +243,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
build_config,
|
||||
image_name=image_name,
|
||||
config_path=args.config,
|
||||
template_name=args.template,
|
||||
distro_name=distro_name,
|
||||
)
|
||||
|
||||
except (Exception, RuntimeError) as exc:
|
||||
|
|
@ -279,7 +271,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if config.external_providers_dir and not config.external_providers_dir.exists():
|
||||
config.external_providers_dir.mkdir(exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name)
|
||||
run_args = formulate_run_args(args.image_type, image_name or config.image_name)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
|
||||
run_command(run_args)
|
||||
|
||||
|
|
@ -362,20 +354,17 @@ def _generate_run_config(
|
|||
def _run_stack_build_command_from_build_config(
|
||||
build_config: BuildConfig,
|
||||
image_name: str | None = None,
|
||||
template_name: str | None = None,
|
||||
distro_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
) -> Path | Traversable:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
if distro_name:
|
||||
image_name = f"distribution-{distro_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
else:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
|
|
@ -385,9 +374,9 @@ def _run_stack_build_command_from_build_config(
|
|||
if image_name is None:
|
||||
raise ValueError("image_name should not be None after validation")
|
||||
|
||||
if template_name:
|
||||
build_dir = DISTRIBS_BASE_DIR / template_name
|
||||
build_file_path = build_dir / f"{template_name}-build.yaml"
|
||||
if distro_name:
|
||||
build_dir = DISTRIBS_BASE_DIR / distro_name
|
||||
build_file_path = build_dir / f"{distro_name}-build.yaml"
|
||||
else:
|
||||
if image_name is None:
|
||||
raise ValueError("image_name cannot be None")
|
||||
|
|
@ -398,7 +387,7 @@ def _run_stack_build_command_from_build_config(
|
|||
run_config_file = None
|
||||
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||
# Only do this if we're building a container image and we're not using a template
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not distro_name and config_path:
|
||||
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
|
||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||
|
||||
|
|
@ -431,48 +420,46 @@ def _run_stack_build_command_from_build_config(
|
|||
|
||||
return_code = build_image(
|
||||
build_config,
|
||||
build_file_path,
|
||||
image_name,
|
||||
template_or_config=template_name or config_path or str(build_file_path),
|
||||
distro_or_config=distro_name or config_path or str(build_file_path),
|
||||
run_config=run_config_file.as_posix() if run_config_file else None,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"Failed to build image {image_name}")
|
||||
|
||||
if template_name:
|
||||
# copy run.yaml from template to build_dir instead of generating it again
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
|
||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||
if distro_name:
|
||||
# copy run.yaml from distribution to build_dir instead of generating it again
|
||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
||||
|
||||
with importlib.resources.as_file(template_path) as path:
|
||||
with importlib.resources.as_file(distro_path) as path:
|
||||
shutil.copy(path, run_config_file)
|
||||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro via: "
|
||||
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return template_path
|
||||
return distro_path
|
||||
else:
|
||||
return _generate_run_config(build_config, build_dir, image_name)
|
||||
|
||||
|
||||
def _run_template_list_cmd() -> None:
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
def _run_distro_list_cmd() -> None:
|
||||
headers = [
|
||||
"Template Name",
|
||||
"Distribution Name",
|
||||
# "Providers",
|
||||
"Description",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for template_name, spec in available_templates_specs().items():
|
||||
for distro_name, spec in available_distros_specs().items():
|
||||
rows.append(
|
||||
[
|
||||
template_name,
|
||||
distro_name,
|
||||
# json.dumps(spec.distribution_spec.providers, indent=2),
|
||||
spec.distribution_spec.description,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -27,21 +27,31 @@ class StackBuild(Subcommand):
|
|||
"--config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack.cores/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates",
|
||||
help="""(deprecated) Name of the example template config to use for build. You may use `llama stack build --list-distros` to check out the available distributions""",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--distro",
|
||||
"--distribution",
|
||||
dest="distribution",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""Name of the distribution to use for build. You may use `llama stack build --list-distros` to check out the available distributions""",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--list-templates",
|
||||
"--list-distros",
|
||||
"--list-distributions",
|
||||
action="store_true",
|
||||
dest="list_distros",
|
||||
default=False,
|
||||
help="Show the available templates for building a Llama Stack distribution",
|
||||
help="Show the available distributions for building a Llama Stack distribution",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
|
|
@ -56,7 +66,7 @@ class StackBuild(Subcommand):
|
|||
"--image-name",
|
||||
type=str,
|
||||
help=textwrap.dedent(
|
||||
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
|
||||
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the virtual environment to use for
|
||||
the build. If not specified, currently active environment will be used if found.
|
||||
"""
|
||||
),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class StackListApis(Subcommand):
|
|||
|
||||
def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.distribution import stack_apis
|
||||
from llama_stack.core.distribution import stack_apis
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class StackListProviders(Subcommand):
|
|||
|
||||
@property
|
||||
def providable_apis(self):
|
||||
from llama_stack.distribution.distribution import providable_apis
|
||||
from llama_stack.core.distribution import providable_apis
|
||||
|
||||
return [api.value for api in providable_apis()]
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class StackListProviders(Subcommand):
|
|||
|
||||
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.distribution import Api, get_provider_registry
|
||||
from llama_stack.core.distribution import Api, get_provider_registry
|
||||
|
||||
all_providers = get_provider_registry()
|
||||
if args.api:
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ class StackRun(Subcommand):
|
|||
"config",
|
||||
type=str,
|
||||
nargs="?", # Make it optional
|
||||
metavar="config | template",
|
||||
help="Path to config file to use for the run or name of known template (`llama stack list` for a list).",
|
||||
metavar="config | distro",
|
||||
help="Path to config file to use for the run or name of known distro (`llama stack list` for a list).",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
|
|
@ -47,7 +47,8 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--image-name",
|
||||
type=str,
|
||||
help="Name of the image to run.",
|
||||
default=None,
|
||||
help="Name of the image to run. Defaults to the current environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
|
|
@ -58,7 +59,7 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be either conda or container or venv.",
|
||||
help="Image Type used during the build. This can be only venv.",
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
|
@ -67,44 +68,62 @@ class StackRun(Subcommand):
|
|||
help="Start the UI server",
|
||||
)
|
||||
|
||||
# If neither image type nor image name is provided, but at the same time
|
||||
# the current environment has conda breadcrumbs, then assume what the user
|
||||
# wants to use conda mode and not the usual default mode (using
|
||||
# pre-installed system packages).
|
||||
#
|
||||
# Note: yes, this is hacky. It's implemented this way to keep the existing
|
||||
# conda users unaffected by the switch of the default behavior to using
|
||||
# system packages.
|
||||
def _get_image_type_and_name(self, args: argparse.Namespace) -> tuple[str, str]:
|
||||
conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
if conda_env and args.image_name == conda_env:
|
||||
logger.warning(f"Conda detected. Using conda environment {conda_env} for the run.")
|
||||
return ImageType.CONDA.value, args.image_name
|
||||
return args.image_type, args.image_name
|
||||
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and distribution name from args.config"""
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
distro_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a distribution
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
distro_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, distro_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
image_type, image_name = args.image_type, args.image_name
|
||||
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
||||
config_file = resolve_config_or_template(args.config, Mode.RUN)
|
||||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
|
||||
self.parser.error("Config file is required for venv and conda environments")
|
||||
if image_type == ImageType.VENV.value and not config_file:
|
||||
self.parser.error("Config file is required for venv environment")
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
|
@ -127,7 +146,7 @@ class StackRun(Subcommand):
|
|||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.distribution.server.server import main as server_main
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
server_args = argparse.Namespace()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from enum import Enum
|
|||
|
||||
|
||||
class ImageType(Enum):
|
||||
CONDA = "conda"
|
||||
CONTAINER = "container"
|
||||
VENV = "venv"
|
||||
|
||||
|
|
|
|||
|
|
@ -11,38 +11,19 @@ from llama_stack.log import get_logger
|
|||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
def add_config_template_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/template arguments with backward compatibility."""
|
||||
# TODO: this can probably just be inlined now?
|
||||
def add_config_distro_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/distro arguments."""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
group.add_argument(
|
||||
"config",
|
||||
nargs="?",
|
||||
help="Configuration file path or template name",
|
||||
)
|
||||
|
||||
# Backward compatibility arguments (deprecated)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
dest="config_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--template",
|
||||
dest="template_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Template name",
|
||||
help="Configuration file path or distribution name",
|
||||
)
|
||||
|
||||
|
||||
def get_config_from_args(args: argparse.Namespace) -> str | None:
|
||||
"""Extract config value from parsed arguments, handling both new and deprecated forms."""
|
||||
if args.config is not None:
|
||||
return str(args.config)
|
||||
elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
|
||||
logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
|
||||
return str(args.config_deprecated)
|
||||
elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
|
||||
logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
|
||||
return str(args.template_deprecated)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -
|
|||
|
||||
|
||||
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
console = Console()
|
||||
model_dir = Path(model_local_dir(args.model_id))
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.core.datatypes import User
|
||||
|
||||
from .conditions import (
|
||||
Condition,
|
||||
|
|
@ -7,18 +7,17 @@
|
|||
import importlib.resources
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.core.datatypes import BuildConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.utils.exec import run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.templates.template import DistributionTemplate
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -106,9 +105,8 @@ def print_pip_install_help(config: BuildConfig):
|
|||
|
||||
def build_image(
|
||||
build_config: BuildConfig,
|
||||
build_file_path: Path,
|
||||
image_name: str,
|
||||
template_or_config: str,
|
||||
distro_or_config: str,
|
||||
run_config: str | None = None,
|
||||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
|
||||
|
|
@ -122,11 +120,11 @@ def build_image(
|
|||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
"--template-or-config",
|
||||
template_or_config,
|
||||
"--distro-or-config",
|
||||
distro_or_config,
|
||||
"--image-name",
|
||||
image_name,
|
||||
"--container-base",
|
||||
|
|
@ -138,19 +136,8 @@ def build_image(
|
|||
# build arguments
|
||||
if run_config is not None:
|
||||
args.extend(["--run-config", run_config])
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--build-file-path",
|
||||
str(build_file_path),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
else:
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
|
|
@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
|
||||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
# Mount command for cache container .cache, can be overridden by the user if needed
|
||||
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
|
||||
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
|
|
@ -47,7 +43,7 @@ normal_deps=""
|
|||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
run_config=""
|
||||
template_or_config=""
|
||||
distro_or_config=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
|
|
@ -100,12 +96,12 @@ while [[ $# -gt 0 ]]; do
|
|||
run_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
--template-or-config)
|
||||
--distro-or-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --template-or-config requires a string value" >&2
|
||||
echo "Error: --distro-or-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
template_or_config="$2"
|
||||
distro_or_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
|
|
@ -176,18 +172,13 @@ RUN pip install uv
|
|||
EOF
|
||||
fi
|
||||
|
||||
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
|
||||
add_to_container << EOF
|
||||
ENV UV_LINK_MODE=copy
|
||||
EOF
|
||||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$normal_deps" ]; then
|
||||
read -ra pip_args <<< "$normal_deps"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
|
|||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
|
@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
|
|||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
add_to_container <<EOF
|
||||
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
|
||||
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
|
|
@ -293,7 +284,7 @@ COPY $dir $mount_point
|
|||
EOF
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install -e $mount_point
|
||||
RUN uv pip install --no-cache -e $mount_point
|
||||
EOF
|
||||
}
|
||||
|
||||
|
|
@ -308,10 +299,10 @@ else
|
|||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install fastapi libcst
|
||||
RUN uv pip install --no-cache fastapi libcst
|
||||
EOF
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
|
|
@ -323,7 +314,7 @@ EOF
|
|||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
|
||||
RUN uv pip install --no-cache $SPEC_VERSION
|
||||
EOF
|
||||
fi
|
||||
fi
|
||||
|
|
@ -336,12 +327,11 @@ EOF
|
|||
# If a run config is provided, we use the --config flag
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
# If a template is provided (not a yaml file), we use the --template flag
|
||||
elif [[ "$template_or_config" != *.yaml ]]; then
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
@ -6,9 +6,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# TODO: combine this with build_conda_env.sh since it is almost identical
|
||||
# the only difference is that we don't do any conda-specific setup
|
||||
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
|
@ -95,6 +92,8 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
|||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
ENVNAME=""
|
||||
|
||||
# pre-run checks to make sure we can proceed with the installation
|
||||
pre_run_checks() {
|
||||
local env_name="$1"
|
||||
|
|
@ -7,12 +7,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
cleanup() {
|
||||
envname="$1"
|
||||
|
||||
set +x
|
||||
echo "Cleaning up..."
|
||||
conda deactivate
|
||||
conda env remove --name "$envname" -y
|
||||
# For venv environments, no special cleanup is needed
|
||||
# This function exists to avoid "function not found" errors
|
||||
local env_name="$1"
|
||||
echo "Cleanup called for environment: $env_name"
|
||||
}
|
||||
|
||||
handle_int() {
|
||||
|
|
@ -31,19 +29,7 @@ handle_exit() {
|
|||
fi
|
||||
}
|
||||
|
||||
setup_cleanup_handlers() {
|
||||
trap handle_int INT
|
||||
trap handle_exit EXIT
|
||||
|
||||
if is_command_available conda; then
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
|
||||
eval "$__conda_setup"
|
||||
conda deactivate
|
||||
else
|
||||
echo "conda is not available"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# check if a command is present
|
||||
is_command_available() {
|
||||
|
|
@ -7,20 +7,20 @@ import logging
|
|||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import (
|
||||
from llama_stack.core.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
|||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
|
@ -432,8 +432,8 @@ class BuildConfig(BaseModel):
|
|||
|
||||
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
|
||||
image_type: str = Field(
|
||||
default="conda",
|
||||
description="Type of package to build (conda | container | venv)",
|
||||
default="venv",
|
||||
description="Type of package to build (container | venv)",
|
||||
)
|
||||
image_name: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -12,8 +12,8 @@ from typing import Any
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
import yaml
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
|
@ -15,9 +15,9 @@ from llama_stack.apis.inspect import (
|
|||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
|
|
@ -31,23 +31,23 @@ from pydantic import BaseModel, TypeAdapter
|
|||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.distribution.request_headers import (
|
||||
from llama_stack.core.build import print_pip_install_help
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.distribution.stack import (
|
||||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
|
|
@ -138,14 +138,14 @@ class LibraryClientHttpxResponse:
|
|||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
config_path_or_distro_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_template_name, custom_provider_registry, provider_data
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
|
|
@ -212,7 +212,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
config_path_or_distro_name: str,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
):
|
||||
|
|
@ -222,20 +222,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if config_path_or_template_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_template_name)
|
||||
if config_path_or_distro_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# template
|
||||
config = get_stack_run_config_from_template(config_path_or_template_name)
|
||||
# distribution
|
||||
config = get_stack_run_config_from_distro(config_path_or_distro_name)
|
||||
|
||||
self.config_path_or_template_name = config_path_or_template_name
|
||||
self.config_path_or_distro_name = config_path_or_distro_name
|
||||
self.config = config
|
||||
self.custom_provider_registry = custom_provider_registry
|
||||
self.provider_data = provider_data
|
||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
try:
|
||||
|
|
@ -244,11 +245,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
|
||||
"Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
if self.config_path_or_distro_name.endswith(".yaml"):
|
||||
providers: dict[str, list[BuildProvider]] = {}
|
||||
for api, run_providers in self.config.providers.items():
|
||||
for provider in run_providers:
|
||||
|
|
@ -266,7 +267,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
else:
|
||||
prefix = "!" if in_notebook() else ""
|
||||
cprint(
|
||||
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
|
||||
f"Please run:\n\n{prefix}llama stack build --distro {self.config_path_or_distro_name} --image-type venv\n\n",
|
||||
"yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
@ -282,7 +283,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:")
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
|
|
@ -297,8 +298,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if not self.route_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized. Please call initialize() first.")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = options.headers or {}
|
||||
|
|
@ -353,9 +354,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
cast_to: Any,
|
||||
options: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
|
@ -412,9 +411,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
options: Any,
|
||||
stream_cls: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
|
@ -474,9 +471,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not body:
|
||||
return {}
|
||||
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
|
|
@ -10,7 +10,7 @@ import logging
|
|||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.core.datatypes import User
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
|
|
@ -27,18 +27,18 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
RoutingTableProviderSpec,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
|
|
@ -183,7 +183,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
|
|||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
module="llama_stack.core.routers",
|
||||
api_dependencies=[],
|
||||
deps__=[f"inner-{info.router_api.value}"],
|
||||
),
|
||||
|
|
@ -197,7 +197,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
|
|||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
module="llama_stack.core.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
# Add telemetry as an optional dependency to all auto-routed providers
|
||||
|
|
@ -6,9 +6,9 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
|
||||
from llama_stack.distribution.stack import StackRunConfig
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.core.datatypes import AccessRule, RoutedProtocol
|
||||
from llama_stack.core.stack import StackRunConfig
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
|
|
@ -79,11 +80,9 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
|
@ -190,7 +189,7 @@ class InferenceRouter(Inference):
|
|||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
if tool_config:
|
||||
|
|
@ -319,7 +318,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
|
|
@ -392,7 +391,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
|
|
@ -432,7 +431,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
|
||||
|
||||
|
|
@ -493,7 +492,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||
|
||||
|
|
@ -564,7 +563,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is not an embedding model")
|
||||
|
||||
|
|
@ -43,6 +43,10 @@ class SafetyRouter(Safety):
|
|||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||
return await self.routing_table.unregister_shield(identifier)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
@ -7,7 +7,7 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
BenchmarkWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -6,19 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import Action
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
|
@ -59,6 +60,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_vector_db(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.safety:
|
||||
return await p.unregister_shield(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
|
|
@ -257,7 +260,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
|
|||
models = await routing_table.get_all_with_type("model")
|
||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||
if len(matching_models) == 0:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
||||
if len(matching_models) > 1:
|
||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import DatasetNotFoundError
|
||||
from llama_stack.apis.datasets import (
|
||||
Dataset,
|
||||
DatasetPurpose,
|
||||
|
|
@ -18,7 +19,7 @@ from llama_stack.apis.datasets import (
|
|||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
DatasetWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -35,7 +36,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset '{dataset_id}' not found")
|
||||
raise DatasetNotFoundError(dataset_id)
|
||||
return dataset
|
||||
|
||||
async def register_dataset(
|
||||
|
|
@ -87,6 +88,4 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
await self.unregister_object(dataset)
|
||||
|
|
@ -7,8 +7,9 @@
|
|||
import time
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
|
|
@ -111,7 +112,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
raise ModelNotFoundError(model_id)
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def update_registered_models(
|
||||
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
ScoringFnParams,
|
||||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
ShieldWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -55,3 +55,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
await self.register_object(shield)
|
||||
return shield
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
existing_shield = await self.get_shield(identifier)
|
||||
await self.unregister_object(existing_shield)
|
||||
|
|
@ -7,8 +7,9 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.core.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
|
@ -87,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group '{toolgroup_id}' not found")
|
||||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
|
|
@ -125,7 +126,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
|
|
@ -22,7 +23,7 @@ from llama_stack.apis.vector_io.vector_io import (
|
|||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -39,7 +40,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
if vector_db is None:
|
||||
raise ValueError(f"Vector DB '{vector_db_id}' not found")
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return vector_db
|
||||
|
||||
async def register_vector_db(
|
||||
|
|
@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
|
|
@ -83,8 +84,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
existing_vector_db = await self.get_vector_db(vector_db_id)
|
||||
if existing_vector_db is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
|
|
@ -9,10 +9,10 @@ import json
|
|||
import httpx
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, User
|
||||
from llama_stack.distribution.request_headers import user_from_scope
|
||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.datatypes import AuthenticationConfig, User
|
||||
from llama_stack.core.request_headers import user_from_scope
|
||||
from llama_stack.core.server.auth_providers import create_auth_provider
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
|
@ -14,7 +14,7 @@ import httpx
|
|||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
|
|
@ -15,7 +15,7 @@ from starlette.routing import Route
|
|||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
EndpointFunc = Callable[..., Any]
|
||||
|
|
@ -32,36 +32,36 @@ from openai import BadRequestError
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.distribution.datatypes import (
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.routes import (
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
from llama_stack.core.stack import (
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
|
|
@ -377,7 +377,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_template_args(parser)
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
|
|
@ -396,8 +396,8 @@ def main(args: argparse.Namespace | None = None):
|
|||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_template = get_config_from_args(args)
|
||||
config_file = resolve_config_or_template(config_or_template, Mode.RUN)
|
||||
config_or_distro = get_config_from_args(args)
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
|
|
@ -34,14 +34,14 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.core.store.registry import create_dist_registry
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -94,6 +94,7 @@ RESOURCES = [
|
|||
|
||||
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
|
||||
REGISTRY_REFRESH_TASK = None
|
||||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
|
|
@ -307,6 +308,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
|
|
@ -352,13 +362,20 @@ async def shutdown_stack(impls: dict[Api, Any]):
|
|||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
logger.info("refreshing registry")
|
||||
logger.debug("refreshing registry")
|
||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
|
|
@ -372,12 +389,12 @@ async def refresh_registry_task(impls: dict[Api, Any]):
|
|||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||
def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"
|
||||
|
||||
with importlib.resources.as_file(template_path) as path:
|
||||
with importlib.resources.as_file(distro_path) as path:
|
||||
if not path.exists():
|
||||
raise ValueError(f"Template '{template}' not found at {template_path}")
|
||||
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
||||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
|
|
@ -40,7 +40,6 @@ port="$1"
|
|||
shift
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Initialize variables
|
||||
yaml_config=""
|
||||
|
|
@ -75,9 +74,9 @@ while [[ $# -gt 0 ]]; do
|
|||
esac
|
||||
done
|
||||
|
||||
# Check if yaml_config is required based on env_type
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
|
||||
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
|
||||
# Check if yaml_config is required
|
||||
if [[ "$env_type" == "venv" ]] && [ -z "$yaml_config" ]; then
|
||||
echo -e "${RED}Error: --config is required for venv environment${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
@ -101,19 +100,14 @@ case "$env_type" in
|
|||
source "$env_path_or_name/bin/activate"
|
||||
fi
|
||||
;;
|
||||
"conda")
|
||||
if ! is_command_available conda; then
|
||||
echo -e "${RED}Error: conda not found" >&2
|
||||
exit 1
|
||||
fi
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "$env_path_or_name"
|
||||
PYTHON_BINARY="$CONDA_PREFIX/bin/python"
|
||||
;;
|
||||
*)
|
||||
# Handle unsupported env_types here
|
||||
echo -e "${RED}Error: Unsupported environment type '$env_type'. Only 'venv' is supported.${NC}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
if [[ "$env_type" == "venv" ]]; then
|
||||
set -x
|
||||
|
||||
if [ -n "$yaml_config" ]; then
|
||||
|
|
@ -122,7 +116,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
|||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
$PYTHON_BINARY -m llama_stack.core.server.server \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
|
|
@ -10,8 +10,8 @@ from typing import Protocol
|
|||
|
||||
import pydantic
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
|
@ -9,7 +9,7 @@
|
|||
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
|
||||
|
||||
```
|
||||
llama stack build --template together --image-type conda
|
||||
llama stack build --distro together --image-type venv
|
||||
|
||||
llama stack run together
|
||||
```
|
||||
|
|
@ -36,7 +36,7 @@ llama-stack-client benchmarks register \
|
|||
3. Start Streamlit UI
|
||||
|
||||
```bash
|
||||
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
|
||||
uv run --with ".[ui]" streamlit run llama_stack/core/ui/app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def benchmarks():
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.distribution.ui.page.distribution.models import models
|
||||
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||
from llama_stack.core.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.core.ui.page.distribution.models import models
|
||||
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.core.ui.page.distribution.shields import shields
|
||||
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue