mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
Merge branch 'main' into chroma
This commit is contained in:
commit
f856e53323
1881 changed files with 886579 additions and 84028 deletions
|
|
@ -27,7 +27,8 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
|
@ -41,6 +42,20 @@ from .openai_responses import (
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ResponseShieldSpec(BaseModel):
|
||||
"""Specification for a shield to apply during response generation.
|
||||
|
||||
:param type: The type/identifier of the shield.
|
||||
"""
|
||||
|
||||
type: str
|
||||
# TODO: more fields to be added for shield configuration
|
||||
|
||||
|
||||
ResponseShield = str | ResponseShieldSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
|
|
@ -471,17 +486,23 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents API for creating and interacting with agentic systems.
|
||||
"""Agents
|
||||
|
||||
Main functionalities provided by this API:
|
||||
- Create agents with specific instructions and ability to use tools.
|
||||
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
|
||||
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
|
||||
- Agents can be provided with various shields (see the Safety API for more details).
|
||||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
APIs for creating and interacting with agentic systems."""
|
||||
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
|
@ -494,7 +515,17 @@ class Agents(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
|
|
@ -524,6 +555,14 @@ class Agents(Protocol):
|
|||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
|
|
@ -549,6 +588,13 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
|
|
@ -568,6 +614,13 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_step(
|
||||
self,
|
||||
|
|
@ -586,7 +639,19 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
@ -600,7 +665,17 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -616,7 +691,17 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -629,7 +714,13 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
@ -640,7 +731,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
|
|
@ -650,7 +742,13 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
|
|
@ -659,7 +757,13 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/sessions",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
@ -682,25 +786,33 @@ class Agents(Protocol):
|
|||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
"""Retrieve an OpenAI response by its ID.
|
||||
"""Get a model response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to retrieve.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
|
@ -708,18 +820,27 @@ class Agents(Protocol):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
shields: Annotated[
|
||||
list[ResponseShield] | None,
|
||||
ExtraBodyField(
|
||||
"List of shields to apply during response generation. Shields provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
"""Create a model response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET")
|
||||
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
@ -727,7 +848,7 @@ class Agents(Protocol):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""List all OpenAI responses.
|
||||
"""List all responses.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The number of responses to return.
|
||||
|
|
@ -737,7 +858,10 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
|
|
@ -747,7 +871,7 @@ class Agents(Protocol):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
"""List input items.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
|
|
@ -759,9 +883,10 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE")
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
"""Delete a response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to delete.
|
||||
:returns: An OpenAIDeleteResponseObject
|
||||
|
|
|
|||
|
|
@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
|||
tools: list[MCPListToolsTool]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalRequest(BaseModel):
|
||||
"""
|
||||
A request for human approval of a tool invocation.
|
||||
"""
|
||||
|
||||
arguments: str
|
||||
id: str
|
||||
name: str
|
||||
server_label: str
|
||||
type: Literal["mcp_approval_request"] = "mcp_approval_request"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalResponse(BaseModel):
|
||||
"""
|
||||
A response to an MCP approval request.
|
||||
"""
|
||||
|
||||
approval_request_id: str
|
||||
approve: bool
|
||||
type: Literal["mcp_approval_response"] = "mcp_approval_response"
|
||||
id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
| OpenAIResponseOutputMessageMCPListTools
|
||||
| OpenAIResponseMCPApprovalRequest,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
|
@ -319,6 +346,174 @@ class OpenAIResponseText(BaseModel):
|
|||
format: OpenAIResponseTextFormat | None = None
|
||||
|
||||
|
||||
# Must match type Literals of OpenAIResponseInputToolWebSearch below
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
"""Web search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Web search tool type variant to use
|
||||
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
|
||||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
)
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
"""Function tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "function"
|
||||
:param name: Name of the function that can be called
|
||||
:param description: (Optional) Description of what the function does
|
||||
:param parameters: (Optional) JSON schema defining the function's parameters
|
||||
:param strict: (Optional) Whether to enforce strict parameter validation
|
||||
"""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
"""File search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "file_search"
|
||||
:param vector_store_ids: List of vector store identifiers to search within
|
||||
:param filters: (Optional) Additional filters to apply to the search
|
||||
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
|
||||
:param ranking_options: (Optional) Options for ranking and scoring search results
|
||||
"""
|
||||
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_ids: list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
max_num_results: int | None = Field(default=10, ge=1, le=50)
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
|
||||
|
||||
class ApprovalFilter(BaseModel):
|
||||
"""Filter configuration for MCP tool approval requirements.
|
||||
|
||||
:param always: (Optional) List of tool names that always require approval
|
||||
:param never: (Optional) List of tool names that never require approval
|
||||
"""
|
||||
|
||||
always: list[str] | None = None
|
||||
never: list[str] | None = None
|
||||
|
||||
|
||||
class AllowedToolsFilter(BaseModel):
|
||||
"""Filter configuration for restricting which MCP tools can be used.
|
||||
|
||||
:param tool_names: (Optional) List of specific tool names that are allowed
|
||||
"""
|
||||
|
||||
tool_names: list[str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param server_url: URL endpoint of the MCP server
|
||||
:param headers: (Optional) HTTP headers to include when connecting to the server
|
||||
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
server_url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseInputToolMCP,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response object.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseToolMCP, # The only type that differes from that in the inputs is the MCP tool
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
|
||||
|
||||
|
||||
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
|
||||
"""Token details for output tokens in OpenAI response usage.
|
||||
|
||||
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
|
||||
"""
|
||||
|
||||
reasoning_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIResponseUsageInputTokensDetails(BaseModel):
|
||||
"""Token details for input tokens in OpenAI response usage.
|
||||
|
||||
:param cached_tokens: Number of tokens retrieved from cache
|
||||
"""
|
||||
|
||||
cached_tokens: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseUsage(BaseModel):
|
||||
"""Usage information for OpenAI response.
|
||||
|
||||
:param input_tokens: Number of tokens in the input
|
||||
:param output_tokens: Number of tokens in the output
|
||||
:param total_tokens: Total tokens used (input + output)
|
||||
:param input_tokens_details: Detailed breakdown of input token usage
|
||||
:param output_tokens_details: Detailed breakdown of output token usage
|
||||
"""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
input_tokens_details: OpenAIResponseUsageInputTokensDetails | None = None
|
||||
output_tokens_details: OpenAIResponseUsageOutputTokensDetails | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObject(BaseModel):
|
||||
"""Complete OpenAI response object containing generation results and metadata.
|
||||
|
|
@ -335,8 +530,9 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param temperature: (Optional) Sampling temperature used for generation
|
||||
:param text: Text formatting configuration for the response
|
||||
:param top_p: (Optional) Nucleus sampling parameter used for generation
|
||||
:param tools: (Optional) An array of tools the model may call while generating a response.
|
||||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param user: (Optional) User identifier associated with the request
|
||||
:param usage: (Optional) Token usage information for the response
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
|
@ -353,8 +549,9 @@ class OpenAIResponseObject(BaseModel):
|
|||
# before the field was added. New responses will have this set always.
|
||||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
tools: list[OpenAIResponseTool] | None = None
|
||||
truncation: str | None = None
|
||||
user: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -375,7 +572,7 @@ class OpenAIDeleteResponseObject(BaseModel):
|
|||
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||
"""Streaming event indicating a new response has been created.
|
||||
|
||||
:param response: The newly created response object
|
||||
:param response: The response object that was created
|
||||
:param type: Event type identifier, always "response.created"
|
||||
"""
|
||||
|
||||
|
|
@ -383,11 +580,25 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
|||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseInProgress(BaseModel):
|
||||
"""Streaming event indicating the response remains in progress.
|
||||
|
||||
:param response: Current response state while in progress
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.in_progress"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.in_progress"] = "response.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
"""Streaming event indicating a response has been completed.
|
||||
|
||||
:param response: The completed response object
|
||||
:param response: Completed response object
|
||||
:param type: Event type identifier, always "response.completed"
|
||||
"""
|
||||
|
||||
|
|
@ -395,6 +606,34 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
|||
type: Literal["response.completed"] = "response.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseIncomplete(BaseModel):
|
||||
"""Streaming event emitted when a response ends in an incomplete state.
|
||||
|
||||
:param response: Response object describing the incomplete state
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.incomplete"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.incomplete"] = "response.incomplete"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFailed(BaseModel):
|
||||
"""Streaming event emitted when a response fails.
|
||||
|
||||
:param response: Response object describing the failure
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.failed"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.failed"] = "response.failed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
|
||||
"""Streaming event for when a new output item is added to the response.
|
||||
|
|
@ -625,19 +864,46 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartOutputText(BaseModel):
|
||||
"""Text content within a streamed response part.
|
||||
|
||||
:param type: Content part type identifier, always "output_text"
|
||||
:param text: Text emitted for this content part
|
||||
:param annotations: Structured annotations associated with the text
|
||||
:param logprobs: (Optional) Token log probability details
|
||||
"""
|
||||
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
# TODO: add annotations, logprobs, etc.
|
||||
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
|
||||
logprobs: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartRefusal(BaseModel):
|
||||
"""Refusal content within a streamed response part.
|
||||
|
||||
:param type: Content part type identifier, always "refusal"
|
||||
:param refusal: Refusal text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartReasoningText(BaseModel):
|
||||
"""Reasoning text emitted as part of a streamed response.
|
||||
|
||||
:param type: Content part type identifier, always "reasoning_text"
|
||||
:param text: Reasoning text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["reasoning_text"] = "reasoning_text"
|
||||
text: str
|
||||
|
||||
|
||||
OpenAIResponseContentPart = Annotated[
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
||||
|
|
@ -647,15 +913,19 @@ register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
|||
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
|
||||
"""Streaming event for when a new content part is added to a response item.
|
||||
|
||||
:param content_index: Index position of the part within the content array
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param output_index: Index position of the output item in the response
|
||||
:param part: The content part that was added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.added"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
|
|
@ -665,22 +935,269 @@ class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
|
|||
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
|
||||
"""Streaming event for when a content part is completed.
|
||||
|
||||
:param content_index: Index position of the part within the content array
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param output_index: Index position of the output item in the response
|
||||
:param part: The completed content part
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningTextDelta(BaseModel):
|
||||
"""Streaming event for incremental reasoning text updates.
|
||||
|
||||
:param content_index: Index position of the reasoning content part
|
||||
:param delta: Incremental reasoning text being added
|
||||
:param item_id: Unique identifier of the output item being updated
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.reasoning_text.delta"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningTextDone(BaseModel):
|
||||
"""Streaming event for when reasoning text is completed.
|
||||
|
||||
:param content_index: Index position of the reasoning content part
|
||||
:param text: Final complete reasoning text
|
||||
:param item_id: Unique identifier of the completed output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.reasoning_text.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
text: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartReasoningSummary(BaseModel):
|
||||
"""Reasoning summary part in a streamed response.
|
||||
|
||||
:param type: Content part type identifier, always "summary_text"
|
||||
:param text: Summary text
|
||||
"""
|
||||
|
||||
type: Literal["summary_text"] = "summary_text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
|
||||
"""Streaming event for when a new reasoning summary part is added.
|
||||
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param part: The summary part that was added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_part.added"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
|
||||
"""Streaming event for when a reasoning summary part is completed.
|
||||
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param part: The completed summary part
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_part.done"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
|
||||
"""Streaming event for incremental reasoning summary text updates.
|
||||
|
||||
:param delta: Incremental summary text being added
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_text.delta"
|
||||
"""
|
||||
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
|
||||
"""Streaming event for when reasoning summary text is completed.
|
||||
|
||||
:param text: Final complete summary text
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_text.done"
|
||||
"""
|
||||
|
||||
text: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseRefusalDelta(BaseModel):
|
||||
"""Streaming event for incremental refusal text updates.
|
||||
|
||||
:param content_index: Index position of the content part
|
||||
:param delta: Incremental refusal text being added
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.refusal.delta"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.refusal.delta"] = "response.refusal.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseRefusalDone(BaseModel):
|
||||
"""Streaming event for when refusal text is completed.
|
||||
|
||||
:param content_index: Index position of the content part
|
||||
:param refusal: Final complete refusal text
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.refusal.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
refusal: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.refusal.done"] = "response.refusal.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
|
||||
"""Streaming event for when an annotation is added to output text.
|
||||
|
||||
:param item_id: Unique identifier of the item to which the annotation is being added
|
||||
:param output_index: Index position of the output item in the response's output array
|
||||
:param content_index: Index position of the content part within the output item
|
||||
:param annotation_index: Index of the annotation within the content part
|
||||
:param annotation: The annotation object being added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_text.annotation.added"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
annotation_index: int
|
||||
annotation: OpenAIResponseAnnotations
|
||||
sequence_number: int
|
||||
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
|
||||
"""Streaming event for file search calls in progress.
|
||||
|
||||
:param item_id: Unique identifier of the file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.in_progress"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
|
||||
"""Streaming event for file search currently searching.
|
||||
|
||||
:param item_id: Unique identifier of the file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.searching"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
|
||||
"""Streaming event for completed file search calls.
|
||||
|
||||
:param item_id: Unique identifier of the completed file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.completed"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
OpenAIResponseObjectStreamResponseCreated
|
||||
| OpenAIResponseObjectStreamResponseInProgress
|
||||
| OpenAIResponseObjectStreamResponseOutputItemAdded
|
||||
| OpenAIResponseObjectStreamResponseOutputItemDone
|
||||
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||
|
|
@ -700,6 +1217,20 @@ OpenAIResponseObjectStream = Annotated[
|
|||
| OpenAIResponseObjectStreamResponseMcpCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseContentPartAdded
|
||||
| OpenAIResponseObjectStreamResponseContentPartDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningTextDelta
|
||||
| OpenAIResponseObjectStreamResponseReasoningTextDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryPartDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDone
|
||||
| OpenAIResponseObjectStreamResponseRefusalDelta
|
||||
| OpenAIResponseObjectStreamResponseRefusalDone
|
||||
| OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallInProgress
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallSearching
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseIncomplete
|
||||
| OpenAIResponseObjectStreamResponseFailed
|
||||
| OpenAIResponseObjectStreamResponseCompleted,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -725,6 +1256,8 @@ OpenAIResponseInput = Annotated[
|
|||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
| OpenAIResponseMCPApprovalRequest
|
||||
| OpenAIResponseMCPApprovalResponse
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
|
|
@ -733,114 +1266,6 @@ OpenAIResponseInput = Annotated[
|
|||
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
|
||||
|
||||
|
||||
# Must match type Literals of OpenAIResponseInputToolWebSearch below
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
"""Web search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Web search tool type variant to use
|
||||
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
|
||||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
)
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
"""Function tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "function"
|
||||
:param name: Name of the function that can be called
|
||||
:param description: (Optional) Description of what the function does
|
||||
:param parameters: (Optional) JSON schema defining the function's parameters
|
||||
:param strict: (Optional) Whether to enforce strict parameter validation
|
||||
"""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
"""File search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "file_search"
|
||||
:param vector_store_ids: List of vector store identifiers to search within
|
||||
:param filters: (Optional) Additional filters to apply to the search
|
||||
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
|
||||
:param ranking_options: (Optional) Options for ranking and scoring search results
|
||||
"""
|
||||
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_ids: list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
max_num_results: int | None = Field(default=10, ge=1, le=50)
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
|
||||
|
||||
class ApprovalFilter(BaseModel):
|
||||
"""Filter configuration for MCP tool approval requirements.
|
||||
|
||||
:param always: (Optional) List of tool names that always require approval
|
||||
:param never: (Optional) List of tool names that never require approval
|
||||
"""
|
||||
|
||||
always: list[str] | None = None
|
||||
never: list[str] | None = None
|
||||
|
||||
|
||||
class AllowedToolsFilter(BaseModel):
|
||||
"""Filter configuration for restricting which MCP tools can be used.
|
||||
|
||||
:param tool_names: (Optional) List of specific tool names that are allowed
|
||||
"""
|
||||
|
||||
tool_names: list[str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param server_url: URL endpoint of the MCP server
|
||||
:param headers: (Optional) HTTP headers to include when connecting to the server
|
||||
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
server_url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseInputToolMCP,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
class ListOpenAIResponseInputItem(BaseModel):
|
||||
"""List container for OpenAI response input items.
|
||||
|
||||
|
|
@ -861,6 +1286,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|||
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
def to_response_object(self) -> OpenAIResponseObject:
|
||||
"""Convert to OpenAIResponseObject by excluding input field."""
|
||||
return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import (
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BatchInference(Protocol):
|
||||
"""Batch inference API for generating completions and chat completions.
|
||||
|
||||
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
|
||||
|
||||
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
|
||||
including (post-training, evals, etc).
|
||||
"""
|
||||
|
||||
@webmethod(route="/batch-inference/completion", method="POST")
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job:
|
||||
"""Generate completions for a batch of content.
|
||||
|
||||
:param model: The model to use for the completion.
|
||||
:param content_batch: The content to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param response_format: The response format to use for the completion.
|
||||
:param logprobs: The logprobs to use for the completion.
|
||||
:returns: A job for the completion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job:
|
||||
"""Generate chat completions for a batch of messages.
|
||||
|
||||
:param model: The model to use for the chat completion.
|
||||
:param messages_batch: The messages to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param tools: The tools to use for the chat completion.
|
||||
:param tool_choice: The tool choice to use for the chat completion.
|
||||
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
||||
:param response_format: The response format to use for the chat completion.
|
||||
:param logprobs: The logprobs to use for the chat completion.
|
||||
:returns: A job for the chat completion.
|
||||
"""
|
||||
...
|
||||
|
|
@ -8,6 +8,7 @@ from typing import Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
|
|
@ -42,7 +43,8 @@ class Batches(Protocol):
|
|||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST")
|
||||
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
|
|
@ -62,7 +64,8 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
|
|
@ -71,7 +74,8 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
|
|
@ -80,7 +84,8 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET")
|
||||
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -53,7 +54,8 @@ class ListBenchmarksResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET")
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
"""List all benchmarks.
|
||||
|
||||
|
|
@ -61,7 +63,8 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -73,7 +76,8 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def register_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -94,7 +98,8 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE")
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark.
|
||||
|
||||
|
|
|
|||
|
|
@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
|
|||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConversationNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced conversation"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
|
||||
|
||||
|
||||
class InvalidConversationIdError(ValueError):
|
||||
"""raised when a conversation ID has an invalid format"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
super().__init__(message)
|
||||
|
|
|
|||
31
llama_stack/apis/conversations/__init__.py
Normal file
31
llama_stack/apis/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationCreateRequest,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
ConversationUpdateRequest,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"ConversationCreateRequest",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItem",
|
||||
"ConversationItemCreateRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"ConversationItemList",
|
||||
"Conversations",
|
||||
"ConversationUpdateRequest",
|
||||
"Metadata",
|
||||
]
|
||||
260
llama_stack/apis/conversations/conversations.py
Normal file
260
llama_stack/apis/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from openai._types import NotGiven
|
||||
from openai.types.responses.response_includable import ResponseIncludable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
Metadata = dict[str, str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Conversation(BaseModel):
|
||||
"""OpenAI-compatible conversation object."""
|
||||
|
||||
id: str = Field(..., description="The unique ID of the conversation.")
|
||||
object: Literal["conversation"] = Field(
|
||||
default="conversation", description="The object type, which is always conversation."
|
||||
)
|
||||
created_at: int = Field(
|
||||
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default=None,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
|
||||
)
|
||||
items: list[dict] | None = Field(
|
||||
default=None,
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationMessage(BaseModel):
|
||||
"""OpenAI-compatible message item for conversations."""
|
||||
|
||||
id: str = Field(..., description="unique identifier for this message")
|
||||
content: list[dict] = Field(..., description="message content")
|
||||
role: str = Field(..., description="message role")
|
||||
status: str = Field(..., description="message status")
|
||||
type: Literal["message"] = "message"
|
||||
object: Literal["message"] = "message"
|
||||
|
||||
|
||||
ConversationItem = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ConversationItem, name="ConversationItem")
|
||||
|
||||
# Using OpenAI types directly caused issues but some notes for reference:
|
||||
# Note that ConversationItem is a Annotated Union of the types below:
|
||||
# from openai.types.responses import *
|
||||
# from openai.types.responses.response_item import *
|
||||
# from openai.types.conversations import ConversationItem
|
||||
# f = [
|
||||
# ResponseFunctionToolCallItem,
|
||||
# ResponseFunctionToolCallOutputItem,
|
||||
# ResponseFileSearchToolCall,
|
||||
# ResponseFunctionWebSearch,
|
||||
# ImageGenerationCall,
|
||||
# ResponseComputerToolCall,
|
||||
# ResponseComputerToolCallOutputItem,
|
||||
# ResponseReasoningItem,
|
||||
# ResponseCodeInterpreterToolCall,
|
||||
# LocalShellCall,
|
||||
# LocalShellCallOutput,
|
||||
# McpListTools,
|
||||
# McpApprovalRequest,
|
||||
# McpApprovalResponse,
|
||||
# McpCall,
|
||||
# ResponseCustomToolCall,
|
||||
# ResponseCustomToolCallOutput
|
||||
# ]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationCreateRequest(BaseModel):
|
||||
"""Request body for creating a conversation."""
|
||||
|
||||
items: list[ConversationItem] | None = Field(
|
||||
default=[],
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default={},
|
||||
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
|
||||
max_length=16,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationUpdateRequest(BaseModel):
|
||||
"""Request body for updating a conversation."""
|
||||
|
||||
metadata: Metadata = Field(
|
||||
...,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation."""
|
||||
|
||||
id: str = Field(..., description="The deleted conversation identifier")
|
||||
object: str = Field(default="conversation.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemCreateRequest(BaseModel):
|
||||
"""Request body for creating conversation items."""
|
||||
|
||||
items: list[ConversationItem] = Field(
|
||||
...,
|
||||
description="Items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemList(BaseModel):
|
||||
"""List of conversation items with pagination."""
|
||||
|
||||
object: str = Field(default="list", description="Object type")
|
||||
data: list[ConversationItem] = Field(..., description="List of conversation items")
|
||||
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
|
||||
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more items available")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation item."""
|
||||
|
||||
id: str = Field(..., description="The deleted item identifier")
|
||||
object: str = Field(default="conversation.item.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Conversations(Protocol):
|
||||
"""Protocol for conversation management operations."""
|
||||
|
||||
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation.
|
||||
|
||||
:param items: Initial items to include in the conversation context.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The created conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The updated conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The deleted conversation resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param items: Items to include in the conversation context.
|
||||
:returns: List of created items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The conversation item.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list(
|
||||
self,
|
||||
conversation_id: str,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
|
||||
) -> ConversationItemList:
|
||||
"""List items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param after: An item ID to list items after, used in pagination.
|
||||
:param include: Specify additional output data to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned (1-100, default 20).
|
||||
:param order: The order to return items in (asc or desc, default desc).
|
||||
:returns: List of conversation items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The deleted item resource.
|
||||
"""
|
||||
...
|
||||
|
|
@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
|
|
@ -20,7 +21,8 @@ class DatasetIO(Protocol):
|
|||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -44,7 +46,10 @@ class DatasetIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
@webmethod(
|
||||
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -145,7 +146,8 @@ class ListDatasetsResponse(BaseModel):
|
|||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST")
|
||||
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
|
|
@ -214,7 +216,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -226,7 +229,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
|
|
@ -234,7 +238,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar telemetry: Observability and system monitoring
|
||||
:cvar models: Model metadata and management
|
||||
:cvar shields: Safety shield implementations
|
||||
:cvar vector_dbs: Vector database management
|
||||
:cvar datasets: Dataset creation and management
|
||||
:cvar scoring_functions: Scoring function definitions
|
||||
:cvar benchmarks: Benchmark suite management
|
||||
|
|
@ -122,13 +121,13 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
vector_dbs = "vector_dbs"
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
conversations = "conversations"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from llama_stack.apis.common.job_types import Job
|
|||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -83,7 +84,8 @@ class EvaluateResponse(BaseModel):
|
|||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -97,7 +99,10 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -115,7 +120,10 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
|
|
@ -125,7 +133,13 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
|
|
@ -134,7 +148,15 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
|
||||
)
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -103,31 +104,38 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
This API is used to upload documents that can be used with other Llama Stack APIs.
|
||||
"""
|
||||
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST")
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
|
||||
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Upload file.
|
||||
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
||||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
|
||||
- expires_after: Optional form values describing expiration for the file.
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
:param expires_after: Optional form values describing expiration for the file.
|
||||
:returns: An OpenAIFileObject representing the uploaded file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files", method="GET")
|
||||
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
@ -135,7 +143,8 @@ class Files(Protocol):
|
|||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""
|
||||
"""List files.
|
||||
|
||||
Returns a list of files that belong to the user's organization.
|
||||
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
|
||||
|
|
@ -146,12 +155,14 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET")
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Retrieve file.
|
||||
|
||||
Returns information about a specific file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
|
@ -159,25 +170,27 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE")
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileDeleteResponse:
|
||||
"""
|
||||
Delete a file.
|
||||
"""Delete file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET")
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""
|
||||
"""Retrieve file content.
|
||||
|
||||
Returns the contents of the specified file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
|
|
|||
|
|
@ -14,26 +14,26 @@ from typing import (
|
|||
runtime_checkable,
|
||||
)
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry import MetricResponseMixin
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
from enum import StrEnum
|
||||
|
|
@ -777,12 +777,14 @@ class OpenAIChoiceDelta(BaseModel):
|
|||
:param refusal: (Optional) The refusal of the delta
|
||||
:param role: (Optional) The role of the delta
|
||||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
:param reasoning_content: (Optional) The reasoning content from the model (non-standard, for o1/o3 models)
|
||||
"""
|
||||
|
||||
content: str | None = None
|
||||
refusal: str | None = None
|
||||
role: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -817,6 +819,42 @@ class OpenAIChoice(BaseModel):
|
|||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
|
||||
"""Token details for output tokens in OpenAI chat completion usage.
|
||||
|
||||
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
|
||||
"""
|
||||
|
||||
reasoning_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
|
||||
"""Token details for prompt tokens in OpenAI chat completion usage.
|
||||
|
||||
:param cached_tokens: Number of tokens retrieved from cache
|
||||
"""
|
||||
|
||||
cached_tokens: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionUsage(BaseModel):
|
||||
"""Usage information for OpenAI chat completion.
|
||||
|
||||
:param prompt_tokens: Number of tokens in the prompt
|
||||
:param completion_tokens: Number of tokens in the completion
|
||||
:param total_tokens: Total tokens used (prompt + completion)
|
||||
:param input_tokens_details: Detailed breakdown of input token usage
|
||||
:param output_tokens_details: Detailed breakdown of output token usage
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
|
||||
completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletion(BaseModel):
|
||||
"""Response from an OpenAI-compatible chat completion request.
|
||||
|
|
@ -826,6 +864,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
:param object: The object type, which will be "chat.completion"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
:param usage: Token usage information for the completion
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
|
@ -833,6 +872,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
usage: OpenAIChatCompletionUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -844,6 +884,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
:param object: The object type, which will be "chat.completion.chunk"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
:param usage: Token usage information (typically included in final chunk with stream_options)
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
|
@ -851,6 +892,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
usage: OpenAIChatCompletionUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -913,6 +955,7 @@ class OpenAIEmbeddingData(BaseModel):
|
|||
"""
|
||||
|
||||
object: Literal["embedding"] = "embedding"
|
||||
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
|
||||
embedding: list[float] | str
|
||||
index: int
|
||||
|
||||
|
|
@ -973,26 +1016,6 @@ class EmbeddingTaskType(Enum):
|
|||
document = "document"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
"""Response from a batch completion request.
|
||||
|
||||
:param batch: List of completion responses, one for each input in the batch
|
||||
"""
|
||||
|
||||
batch: list[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
"""Response from a batch chat completion request.
|
||||
|
||||
:param batch: List of chat completion responses, one for each conversation in the batch
|
||||
"""
|
||||
|
||||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
|
||||
|
|
@ -1015,6 +1038,108 @@ class ListOpenAIChatCompletionResponse(BaseModel):
|
|||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request parameters for OpenAI-compatible completion endpoint.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:param suffix: (Optional) The suffix that should be appended to the completion.
|
||||
"""
|
||||
|
||||
# Standard OpenAI completion parameters
|
||||
model: str
|
||||
prompt: str | list[str] | list[int] | list[list[int]]
|
||||
best_of: int | None = None
|
||||
echo: bool | None = None
|
||||
frequency_penalty: float | None = None
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: bool | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = None
|
||||
presence_penalty: float | None = None
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
user: str | None = None
|
||||
suffix: str | None = None
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request parameters for OpenAI-compatible chat completion endpoint.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
"""
|
||||
|
||||
# Standard OpenAI chat completion parameters
|
||||
model: str
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
|
||||
frequency_penalty: float | None = None
|
||||
function_call: str | dict[str, Any] | None = None
|
||||
functions: list[dict[str, Any]] | None = None
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: bool | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
presence_penalty: float | None = None
|
||||
response_format: OpenAIResponseFormatParam | None = None
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
top_logprobs: int | None = None
|
||||
top_p: float | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class InferenceProvider(Protocol):
|
||||
|
|
@ -1026,136 +1151,7 @@ class InferenceProvider(Protocol):
|
|||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
@webmethod(route="/inference/completion", method="POST")
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content: The content to generate a completion for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: If stream=False, returns a CompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
"""Generate completions for a batch of content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content_batch: The content to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param sampling_params: Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
|
||||
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
|
||||
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
"""Generate chat completions for a batch of messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages_batch: The messages to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchChatCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1174,114 +1170,34 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST")
|
||||
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
"""Create completion.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:param suffix: (Optional) The suffix that should be appended to the completion.
|
||||
Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
:returns: An OpenAICompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST")
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
"""Create chat completions.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST")
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1290,7 +1206,9 @@ class InferenceProvider(Protocol):
|
|||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
"""Create embeddings.
|
||||
|
||||
Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
||||
|
|
@ -1303,14 +1221,17 @@ class InferenceProvider(Protocol):
|
|||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
"""Inference
|
||||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
@ -1318,7 +1239,7 @@ class Inference(InferenceProvider):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
"""List chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
|
|
@ -1328,9 +1249,14 @@ class Inference(InferenceProvider):
|
|||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
"""Get chat completion.
|
||||
|
||||
Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -57,25 +58,36 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
"""Inspect
|
||||
|
||||
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
|
||||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all available API routes with their methods and implementing providers.
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the current health status of the service.
|
||||
"""Get health status.
|
||||
|
||||
Get the current health status of the service.
|
||||
|
||||
:returns: Health information indicating if the service is operational.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET")
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
"""Get version.
|
||||
|
||||
Get the version of the service.
|
||||
|
||||
:returns: Version information containing the service version number.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -102,7 +103,7 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET")
|
||||
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
||||
|
|
@ -110,7 +111,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET")
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
|
|
@ -118,19 +119,21 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Model:
|
||||
"""Get a model by its identifier.
|
||||
"""Get model.
|
||||
|
||||
Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -139,7 +142,9 @@ class Models(Protocol):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register a model.
|
||||
"""Register model.
|
||||
|
||||
Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
|
|
@ -150,12 +155,14 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> None:
|
||||
"""Unregister a model.
|
||||
"""Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -283,7 +284,8 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
|||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
@ -310,7 +312,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
@ -332,7 +335,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all training jobs.
|
||||
|
||||
|
|
@ -340,7 +344,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
"""Get the status of a training job.
|
||||
|
||||
|
|
@ -349,7 +354,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancel a training job.
|
||||
|
||||
|
|
@ -357,7 +363,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
"""Get the artifacts of a training job.
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -93,9 +94,11 @@ class ListPromptsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
"""Prompts
|
||||
|
||||
@webmethod(route="/prompts", method="GET")
|
||||
Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts.
|
||||
|
||||
|
|
@ -103,25 +106,29 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET")
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompt_versions(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt.
|
||||
"""List prompt versions.
|
||||
|
||||
List all versions of a specific prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to list versions for.
|
||||
:returns: A ListPromptsResponse containing all versions of the prompt.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
"""Get prompt.
|
||||
|
||||
Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
|
|
@ -129,13 +136,15 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts", method="POST")
|
||||
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
"""Create prompt.
|
||||
|
||||
Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
|
|
@ -143,7 +152,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
|
@ -152,7 +161,9 @@ class Prompts(Protocol):
|
|||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
"""Update prompt.
|
||||
|
||||
Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
|
|
@ -163,24 +174,28 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete a prompt.
|
||||
"""Delete prompt.
|
||||
|
||||
Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
"""Set prompt version.
|
||||
|
||||
Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -41,21 +42,26 @@ class ListProvidersResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
"""Providers
|
||||
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
"""List providers.
|
||||
|
||||
List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
"""Get provider.
|
||||
|
||||
Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -95,16 +96,23 @@ class ShieldStore(Protocol):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
OpenAI-compatible Moderations API.
|
||||
"""
|
||||
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield", method="POST")
|
||||
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
"""Run a shield.
|
||||
"""Run shield.
|
||||
|
||||
Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
|
|
@ -113,9 +121,12 @@ class Safety(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST")
|
||||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
"""Create moderation.
|
||||
|
||||
Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# mapping of metric to value
|
||||
|
|
@ -61,7 +62,7 @@ class ScoringFunctionStore(Protocol):
|
|||
class Scoring(Protocol):
|
||||
scoring_function_store: ScoringFunctionStore
|
||||
|
||||
@webmethod(route="/scoring/score-batch", method="POST")
|
||||
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -77,7 +78,7 @@ class Scoring(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/score", method="POST")
|
||||
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score(
|
||||
self,
|
||||
input_rows: list[dict[str, Any]],
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -160,7 +161,7 @@ class ListScoringFunctionsResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
"""List all scoring functions.
|
||||
|
||||
|
|
@ -168,7 +169,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
|
||||
"""Get a scoring function by its ID.
|
||||
|
||||
|
|
@ -177,7 +178,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
|
|
@ -198,7 +199,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
"""Unregister a scoring function.
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ class ListShieldsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET")
|
||||
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
"""List all shields.
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
"""Get a shield by its identifier.
|
||||
|
||||
|
|
@ -66,7 +67,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
@ -84,7 +85,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any, Protocol
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -59,7 +60,7 @@ class SyntheticDataGenerationResponse(BaseModel):
|
|||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic-data-generation/generate")
|
||||
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: list[Message],
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from typing import (
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
|
@ -412,7 +413,7 @@ class QueryMetricsResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/events", method="POST")
|
||||
@webmethod(route="/telemetry/events", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def log_event(
|
||||
self,
|
||||
event: Event,
|
||||
|
|
@ -425,7 +426,14 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/traces",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
|
@ -443,7 +451,19 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
|
|
@ -453,7 +473,17 @@ class Telemetry(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
|
@ -464,7 +494,19 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/spans/{span_id:path}/tree",
|
||||
method="POST",
|
||||
deprecated=True,
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/spans/{span_id:path}/tree",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
|
@ -480,7 +522,14 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/spans",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -496,7 +545,8 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -513,7 +563,19 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/metrics/{metric_name}",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/metrics/{metric_name}",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
|
@ -185,7 +186,7 @@ class RAGQueryConfig(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
|
|
@ -200,7 +201,7 @@ class RAGToolRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
|
|
|
|||
|
|
@ -7,66 +7,35 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter definition for a tool.
|
||||
|
||||
:param name: Name of the parameter
|
||||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
"""A tool that can be invoked by agents.
|
||||
|
||||
:param type: Type of resource, always 'tool'
|
||||
:param toolgroup_id: ID of the tool group this tool belongs to
|
||||
:param description: Human-readable description of what the tool does
|
||||
:param parameters: List of parameters this tool accepts
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
description: str
|
||||
parameters: list[ToolParameter]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Human-readable description of what the tool does
|
||||
:param parameters: (Optional) List of parameters this tool accepts
|
||||
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
|
||||
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
|
||||
"""
|
||||
|
||||
toolgroup_id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: list[ToolParameter] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
|
@ -117,7 +86,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||
async def get_tool(self, tool_name: str) -> ToolDef: ...
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
|
|
@ -130,15 +99,6 @@ class ListToolGroupsResponse(BaseModel):
|
|||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
"""Response containing a list of tools.
|
||||
|
||||
:param data: List of tools
|
||||
"""
|
||||
|
||||
data: list[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
|
|
@ -151,7 +111,7 @@ class ListToolDefsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups", method="POST")
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -168,7 +128,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -180,7 +140,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET")
|
||||
@webmethod(route="/toolgroups", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
"""List tool groups with optional provider.
|
||||
|
||||
|
|
@ -188,28 +148,28 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET")
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool:
|
||||
) -> ToolDef:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
:returns: A ToolDef.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -238,7 +198,7 @@ class ToolRuntime(Protocol):
|
|||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
|
|
@ -250,7 +210,7 @@ class ToolRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -60,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
|
|||
"""
|
||||
|
||||
data: list[VectorDB]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""List all vector databases.
|
||||
|
||||
:returns: A ListVectorDBsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> VectorDB:
|
||||
"""Get a vector database by its identifier.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to get.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
"""Register a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to register.
|
||||
:param embedding_model: The embedding model to use.
|
||||
:param embedding_dimension: The dimension of the embedding model.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param vector_db_name: The name of the vector database.
|
||||
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Unregister a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
|
@ -317,7 +318,8 @@ class VectorStoreChunkingStrategyStatic(BaseModel):
|
|||
|
||||
|
||||
VectorStoreChunkingStrategy = Annotated[
|
||||
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
|
||||
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
|
||||
|
||||
|
|
@ -426,6 +428,44 @@ class VectorStoreFileDeleteResponse(BaseModel):
|
|||
deleted: bool = True
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileBatchObject(BaseModel):
|
||||
"""OpenAI Vector Store File Batch object.
|
||||
|
||||
:param id: Unique identifier for the file batch
|
||||
:param object: Object type identifier, always "vector_store.file_batch"
|
||||
:param created_at: Timestamp when the file batch was created
|
||||
:param vector_store_id: ID of the vector store containing the file batch
|
||||
:param status: Current processing status of the file batch
|
||||
:param file_counts: File processing status counts for the batch
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file_batch"
|
||||
created_at: int
|
||||
vector_store_id: str
|
||||
status: VectorStoreFileStatus
|
||||
file_counts: VectorStoreFileCounts
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFilesListInBatchResponse(BaseModel):
|
||||
"""Response from listing files in a vector store file batch.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store file objects in the batch
|
||||
:param first_id: (Optional) ID of the first file in the list for pagination
|
||||
:param last_id: (Optional) ID of the last file in the list for pagination
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreFileObject]
|
||||
first_id: str | None = None
|
||||
last_id: str | None = None
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
|
||||
|
||||
|
|
@ -437,7 +477,7 @@ class VectorIO(Protocol):
|
|||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
@webmethod(route="/vector-io/insert", method="POST")
|
||||
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -455,7 +495,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-io/query", method="POST")
|
||||
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -472,7 +512,8 @@ class VectorIO(Protocol):
|
|||
...
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST")
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str | None = None,
|
||||
|
|
@ -498,7 +539,8 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET")
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
|
|
@ -516,7 +558,10 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -528,7 +573,14 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -546,7 +598,14 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -558,7 +617,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -567,7 +636,9 @@ class VectorIO(Protocol):
|
|||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
search_mode: (
|
||||
str | None
|
||||
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
"""Search for chunks in a vector store.
|
||||
|
||||
|
|
@ -584,7 +655,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="POST")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -602,7 +683,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -624,7 +715,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -638,7 +739,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content", method="GET")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -652,7 +763,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="POST")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -668,7 +789,17 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -681,3 +812,113 @@ class VectorIO(Protocol):
|
|||
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Create a vector store file batch.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to create the file batch for.
|
||||
:param file_ids: A list of File IDs that the vector store should use.
|
||||
:param attributes: (Optional) Key-value attributes to store with the files.
|
||||
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
|
||||
:returns: A VectorStoreFileBatchObject representing the created file batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Retrieve a vector store file batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to retrieve.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:returns: A VectorStoreFileBatchObject representing the file batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
) -> VectorStoreFilesListInBatchResponse:
|
||||
"""Returns a list of vector store files in a batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to list files from.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:param filter: Filter by file status. One of in_progress, completed, failed, cancelled.
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:returns: A VectorStoreFilesListInBatchResponse containing the list of files in the batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Cancels a vector store file batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to cancel.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:returns: A VectorStoreFileBatchObject representing the cancelled file batch.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -4,4 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_API_VERSION = "v1"
|
||||
LLAMA_STACK_API_V1 = "v1"
|
||||
LLAMA_STACK_API_V1BETA = "v1beta"
|
||||
LLAMA_STACK_API_V1ALPHA = "v1alpha"
|
||||
|
|
|
|||
|
|
@ -1,495 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
"""Llama cli for downloading llama toolchain assets"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"download",
|
||||
prog="llama download",
|
||||
description="Download a model from llama.meta.com or Hugging Face Hub",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
setup_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
choices=["meta", "huggingface"],
|
||||
default="meta",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
required=False,
|
||||
help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta-url",
|
||||
type=str,
|
||||
required=False,
|
||||
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
required=False,
|
||||
default=3,
|
||||
help="Maximum number of concurrent downloads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-patterns",
|
||||
type=str,
|
||||
required=False,
|
||||
default="*.safetensors",
|
||||
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
||||
safetensors files to avoid downloading duplicate weights.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-file",
|
||||
type=str,
|
||||
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
|
||||
required=False,
|
||||
)
|
||||
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadTask:
|
||||
url: str
|
||||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
task_id: int | None = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CustomTransferSpeedColumn(TransferSpeedColumn):
|
||||
def render(self, task):
|
||||
if task.finished:
|
||||
return "-"
|
||||
return super().render(task)
|
||||
|
||||
|
||||
class ParallelDownloader:
|
||||
def __init__(
|
||||
self,
|
||||
max_concurrent_downloads: int = 3,
|
||||
buffer_size: int = 1024 * 1024,
|
||||
timeout: int = 30,
|
||||
):
|
||||
self.max_concurrent_downloads = max_concurrent_downloads
|
||||
self.buffer_size = buffer_size
|
||||
self.timeout = timeout
|
||||
self.console = Console()
|
||||
self.progress = Progress(
|
||||
TextColumn("[bold blue]{task.description}"),
|
||||
BarColumn(bar_width=40),
|
||||
"[progress.percentage]{task.percentage:>3.1f}%",
|
||||
DownloadColumn(),
|
||||
CustomTransferSpeedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
expand=True,
|
||||
)
|
||||
self.client_options = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
|
||||
last_exception = None
|
||||
for attempt in range(task.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < task.max_retries - 1:
|
||||
wait_time = min(30, 2**attempt) # Cap at 30 seconds
|
||||
self.console.print(
|
||||
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
|
||||
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
raise last_exception
|
||||
|
||||
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
|
||||
if task.total_size > 0:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
return
|
||||
|
||||
async def _get_info():
|
||||
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
try:
|
||||
response = await self.retry_with_exponential_backoff(task, _get_info)
|
||||
|
||||
task.url = str(response.url)
|
||||
task.total_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
if task.total_size == 0:
|
||||
raise DownloadError(
|
||||
f"Unable to determine file size for {task.output_file}. "
|
||||
"The server might not support range requests."
|
||||
)
|
||||
|
||||
# Update the progress bar's total size once we know it
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
|
||||
raise
|
||||
|
||||
def verify_file_integrity(self, task: DownloadTask) -> bool:
|
||||
if not os.path.exists(task.output_file):
|
||||
return False
|
||||
return os.path.getsize(task.output_file) == task.total_size
|
||||
|
||||
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
|
||||
async def _download_chunk():
|
||||
headers = {"Range": f"bytes={start}-{end}"}
|
||||
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
with open(task.output_file, "ab") as file:
|
||||
file.seek(start)
|
||||
async for chunk in response.aiter_bytes(self.buffer_size):
|
||||
file.write(chunk)
|
||||
task.downloaded_size += len(chunk)
|
||||
self.progress.update(
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.retry_with_exponential_backoff(task, _download_chunk)
|
||||
except Exception as e:
|
||||
raise DownloadError(
|
||||
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def prepare_download(self, task: DownloadTask) -> None:
|
||||
output_dir = os.path.dirname(task.output_file)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if os.path.exists(task.output_file):
|
||||
task.downloaded_size = os.path.getsize(task.output_file)
|
||||
|
||||
async def download_file(self, task: DownloadTask) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(**self.client_options) as client:
|
||||
await self.get_file_info(client, task)
|
||||
|
||||
# Check if file is already downloaded
|
||||
if os.path.exists(task.output_file):
|
||||
if self.verify_file_integrity(task):
|
||||
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
return
|
||||
|
||||
await self.prepare_download(task)
|
||||
|
||||
try:
|
||||
# Split the remaining download into chunks
|
||||
chunk_size = 27_000_000_000 # Cloudfront max chunk size
|
||||
chunks = []
|
||||
|
||||
current_pos = task.downloaded_size
|
||||
while current_pos < task.total_size:
|
||||
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
|
||||
chunks.append((current_pos, chunk_end))
|
||||
current_pos = chunk_end + 1
|
||||
|
||||
# Download chunks in sequence
|
||||
for chunk_start, chunk_end in chunks:
|
||||
await self.download_chunk(client, task, chunk_start, chunk_end)
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||
|
||||
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
|
||||
try:
|
||||
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||
free_space = shutil.disk_usage(dir_path).free
|
||||
|
||||
# Add 10% buffer for safety
|
||||
required_space = int(total_remaining_size * 1.1)
|
||||
|
||||
if free_space < required_space:
|
||||
self.console.print(
|
||||
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
|
||||
f"Available: {free_space // (1024 * 1024)} MB[/red]"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
||||
|
||||
async def download_all(self, tasks: list[DownloadTask]) -> None:
|
||||
if not tasks:
|
||||
raise ValueError("No download tasks provided")
|
||||
|
||||
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
|
||||
raise DownloadError("Insufficient disk space for downloads")
|
||||
|
||||
failed_tasks = []
|
||||
|
||||
with self.progress:
|
||||
for task in tasks:
|
||||
desc = f"Downloading {Path(task.output_file).name}"
|
||||
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
|
||||
|
||||
async def download_with_semaphore(task: DownloadTask):
|
||||
async with semaphore:
|
||||
try:
|
||||
await self.download_file(task)
|
||||
except Exception as e:
|
||||
failed_tasks.append((task, str(e)))
|
||||
|
||||
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
|
||||
|
||||
if failed_tasks:
|
||||
self.console.print("\n[red]Some downloads failed:[/red]")
|
||||
for task, error in failed_tasks:
|
||||
self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
|
||||
raise DownloadError(f"{len(failed_tasks)} downloads failed")
|
||||
|
||||
|
||||
def _hf_download(
|
||||
model: "Model",
|
||||
hf_token: str,
|
||||
ignore_patterns: str,
|
||||
parser: argparse.ArgumentParser,
|
||||
):
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = model_local_dir(model.descriptor())
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
repo_id,
|
||||
local_dir=output_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
token=hf_token,
|
||||
library_name="llama-stack",
|
||||
)
|
||||
except GatedRepoError:
|
||||
parser.error(
|
||||
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||
"have access to the repository and have provided the proper Hugging Face API token "
|
||||
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
|
||||
except Exception as e:
|
||||
parser.error(e)
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
||||
def _meta_download(
|
||||
model: "Model",
|
||||
model_id: str,
|
||||
meta_url: str,
|
||||
info: "LlamaDownloadInfo",
|
||||
max_concurrent_downloads: int,
|
||||
):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model.descriptor()))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Create download tasks for each file
|
||||
tasks = []
|
||||
for f in info.files:
|
||||
output_file = str(output_dir / f)
|
||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||
total_size = info.pth_size if "consolidated" in f else 0
|
||||
tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
|
||||
cprint(
|
||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
cprint(
|
||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
class ModelEntry(BaseModel):
|
||||
model_id: str
|
||||
files: dict[str, str]
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
models: list[ModelEntry]
|
||||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file) as f:
|
||||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now(UTC) > manifest.expires_on.astimezone(UTC):
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
for entry in manifest.models:
|
||||
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
|
||||
output_dir = Path(model_local_dir(entry.model_id))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if any(output_dir.iterdir()):
|
||||
console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
|
||||
|
||||
while True:
|
||||
resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
|
||||
if resp.lower() in ["restart", "r"]:
|
||||
shutil.rmtree(output_dir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
break
|
||||
elif resp.lower() in ["continue", "c"]:
|
||||
console.print("[blue]Continuing download...[/blue]")
|
||||
break
|
||||
else:
|
||||
console.print("[red]Invalid response. Please try again.[/red]")
|
||||
|
||||
# Create download tasks for all files in the manifest
|
||||
tasks = [
|
||||
DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
|
||||
for fname, url in entry.files.items()
|
||||
]
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
|
||||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
"""Main download command handler"""
|
||||
try:
|
||||
if args.manifest_file:
|
||||
_download_from_manifest(args.manifest_file, args.max_parallel)
|
||||
return
|
||||
|
||||
if args.model_id is None:
|
||||
parser.error("Please provide a model id")
|
||||
return
|
||||
|
||||
# Handle comma-separated model IDs
|
||||
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
|
||||
|
||||
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
|
||||
|
||||
from .model.safety_models import (
|
||||
prompt_guard_download_info_map,
|
||||
prompt_guard_model_sku_map,
|
||||
)
|
||||
|
||||
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
|
||||
prompt_guard_download_info_map = prompt_guard_download_info_map()
|
||||
|
||||
for model_id in model_ids:
|
||||
if model_id in prompt_guard_model_sku_map.keys():
|
||||
model = prompt_guard_model_sku_map[model_id]
|
||||
info = prompt_guard_download_info_map[model_id]
|
||||
else:
|
||||
model = resolve_model(model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {model_id} not found")
|
||||
continue
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url or input(
|
||||
f"Please provide the signed URL for model {model_id} you received via email "
|
||||
f"after visiting https://www.llama.com/llama-downloads/ "
|
||||
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
if "llamameta.net" not in meta_url:
|
||||
parser.error("Invalid Meta URL provided")
|
||||
_meta_download(model, model_id, meta_url, info, args.max_parallel)
|
||||
|
||||
except Exception as e:
|
||||
parser.error(f"Download failed: {str(e)}")
|
||||
|
|
@ -6,11 +6,8 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from .download import Download
|
||||
from .model import ModelParser
|
||||
from .stack import StackParser
|
||||
from .stack.utils import print_subcommand_description
|
||||
from .verify_download import VerifyDownload
|
||||
|
||||
|
||||
class LlamaCLIParser:
|
||||
|
|
@ -30,10 +27,7 @@ class LlamaCLIParser:
|
|||
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
ModelParser.create(subparsers)
|
||||
StackParser.create(subparsers)
|
||||
Download.create(subparsers)
|
||||
VerifyDownload.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
"""Show details about a model"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"describe",
|
||||
prog="llama model describe",
|
||||
description="Show details about a llama model",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_describe_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="See `llama model list` or `llama model list --show-all` for the list of available models",
|
||||
)
|
||||
|
||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
if args.model_id in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model_id]
|
||||
else:
|
||||
model = resolve_model(args.model_id)
|
||||
|
||||
if model is None:
|
||||
self.parser.error(
|
||||
f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
|
||||
)
|
||||
return
|
||||
|
||||
headers = [
|
||||
"Model",
|
||||
model.descriptor(),
|
||||
]
|
||||
|
||||
rows = [
|
||||
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
|
||||
("Description", model.description),
|
||||
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
|
||||
("Weights format", model.quantization_format.value),
|
||||
("Model params.json", json.dumps(model.arch_args, indent=4)),
|
||||
]
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelDownload(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"download",
|
||||
prog="llama model download",
|
||||
description="Download a model from llama.meta.com or Hugging Face Hub",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
from llama_stack.cli.download import setup_download_parser
|
||||
|
||||
setup_download_parser(self.parser)
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
||||
|
||||
def _get_model_size(model_dir):
|
||||
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
|
||||
|
||||
|
||||
def _convert_to_model_descriptor(model):
|
||||
for m in all_registered_models():
|
||||
if model == m.descriptor().replace(":", "-"):
|
||||
return str(m.descriptor())
|
||||
return str(model)
|
||||
|
||||
|
||||
def _run_model_list_downloaded_cmd() -> None:
|
||||
headers = ["Model", "Size", "Modified Time"]
|
||||
|
||||
rows = []
|
||||
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
|
||||
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
|
||||
space_usage = _get_model_size(abs_path)
|
||||
model_size = f"{space_usage / (1024**3):.2f} GB"
|
||||
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
|
||||
rows.append(
|
||||
[
|
||||
_convert_to_model_descriptor(model),
|
||||
model_size,
|
||||
modified_time,
|
||||
]
|
||||
)
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
||||
|
||||
class ModelList(Subcommand):
|
||||
"""List available llama models"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list",
|
||||
prog="llama model list",
|
||||
description="Show available llama models",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_list_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"--show-all",
|
||||
action="store_true",
|
||||
help="Show all models (not just defaults)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--downloaded",
|
||||
action="store_true",
|
||||
help="List the downloaded models",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-s",
|
||||
"--search",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Search for the input string as a substring in the model descriptor(ID)",
|
||||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_skus
|
||||
|
||||
if args.downloaded:
|
||||
return _run_model_list_downloaded_cmd()
|
||||
|
||||
headers = [
|
||||
"Model Descriptor(ID)",
|
||||
"Hugging Face Repo",
|
||||
"Context Length",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for model in all_registered_models() + prompt_guard_model_skus():
|
||||
if not args.show_all and not model.is_featured:
|
||||
continue
|
||||
|
||||
descriptor = model.descriptor()
|
||||
if not args.search or args.search.lower() in descriptor.lower():
|
||||
rows.append(
|
||||
[
|
||||
descriptor,
|
||||
model.huggingface_repo,
|
||||
f"{model.max_seq_length // 1024}K",
|
||||
]
|
||||
)
|
||||
if len(rows) == 0:
|
||||
print(f"Did not find any model matching `{args.search}`.")
|
||||
else:
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.model.describe import ModelDescribe
|
||||
from llama_stack.cli.model.download import ModelDownload
|
||||
from llama_stack.cli.model.list import ModelList
|
||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.remove import ModelRemove
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelParser(Subcommand):
|
||||
"""Llama cli for model interface apis"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"model",
|
||||
prog="llama model",
|
||||
description="Work with llama models",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="model_subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
ModelDownload.create(subparsers)
|
||||
ModelList.create(subparsers)
|
||||
ModelPromptFormat.create(subparsers)
|
||||
ModelDescribe.create(subparsers)
|
||||
ModelVerifyDownload.create(subparsers)
|
||||
ModelRemove.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ModelPromptFormat(Subcommand):
|
||||
"""Llama model cli for describe a model prompt format (message formats)"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"prompt-format",
|
||||
prog="llama model prompt-format",
|
||||
description="Show llama model message formats",
|
||||
epilog=textwrap.dedent(
|
||||
"""
|
||||
Example:
|
||||
llama model prompt-format <options>
|
||||
"""
|
||||
),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_template_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
"--list",
|
||||
action="store_true",
|
||||
help="List all available models",
|
||||
)
|
||||
|
||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||
import importlib.resources
|
||||
|
||||
# Only Llama 3.1 and 3.2 are supported
|
||||
supported_model_ids = [
|
||||
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
]
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
rows = []
|
||||
for m in model_list:
|
||||
rows.append(
|
||||
[
|
||||
m,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
|
||||
if model_family(model_id) == ModelFamily.llama3_1:
|
||||
with importlib.resources.as_file(llama_3_1_file) as f:
|
||||
content = f.open("r").read()
|
||||
elif model_family(model_id) == ModelFamily.llama3_2:
|
||||
if is_multimodal(model_id):
|
||||
with importlib.resources.as_file(llama_3_2_vision_file) as f:
|
||||
content = f.open("r").read()
|
||||
else:
|
||||
with importlib.resources.as_file(llama_3_2_text_file) as f:
|
||||
content = f.open("r").read()
|
||||
|
||||
render_markdown_to_pager(content)
|
||||
|
||||
|
||||
def render_markdown_to_pager(markdown_content: str):
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
|
||||
class LeftAlignedHeaderMarkdown(Markdown):
|
||||
def parse_header(self, token):
|
||||
level = token.type.count("h")
|
||||
content = Text(token.content)
|
||||
header_style = Style(color="bright_blue", bold=True)
|
||||
header = Text(f"{'#' * level} ", style=header_style) + content
|
||||
self.add_text(header)
|
||||
|
||||
# Render the Markdown
|
||||
md = LeftAlignedHeaderMarkdown(markdown_content)
|
||||
|
||||
# Capture the rendered output
|
||||
output = StringIO()
|
||||
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
|
||||
console.print(md)
|
||||
rendered_content = output.getvalue()
|
||||
print(rendered_content)
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelRemove(Subcommand):
|
||||
"""Remove the downloaded llama model"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"remove",
|
||||
prog="llama model remove",
|
||||
description="Remove the downloaded llama model",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_remove_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=True,
|
||||
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Used to forcefully remove the llama model from the storage without further confirmation",
|
||||
)
|
||||
|
||||
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
|
||||
if args.model in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model]
|
||||
else:
|
||||
model = resolve_model(args.model)
|
||||
|
||||
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
|
||||
|
||||
if model is None or not os.path.isdir(model_path):
|
||||
print(f"'{args.model}' is not a valid llama model or does not exist.")
|
||||
return
|
||||
|
||||
if args.force:
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
print("Removal aborted.")
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
|
||||
|
||||
|
||||
class PromptGuardModel(BaseModel):
|
||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||
|
||||
model_id: str
|
||||
huggingface_repo: str
|
||||
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
|
||||
is_featured: bool = False
|
||||
max_seq_length: int = 512
|
||||
is_instruct_model: bool = False
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
arch_args: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def descriptor(self) -> str:
|
||||
return self.model_id
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
def prompt_guard_model_skus():
|
||||
return [
|
||||
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-86M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
|
||||
),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-22M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prompt_guard_model_sku_map() -> dict[str, Any]:
|
||||
return {model.model_id: model for model in prompt_guard_model_skus()}
|
||||
|
||||
|
||||
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
|
||||
return {
|
||||
model.model_id: LlamaDownloadInfo(
|
||||
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
|
||||
files=[
|
||||
"model.safetensors",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
],
|
||||
pth_size=1,
|
||||
)
|
||||
for model in prompt_guard_model_skus()
|
||||
}
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelVerifyDownload(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama model verify-download",
|
||||
description="Verify the downloaded checkpoints' checksums for models downloaded from Meta",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
from llama_stack.cli.verify_download import setup_verify_download_parser
|
||||
|
||||
setup_verify_download_parser(self.parser)
|
||||
|
|
@ -444,12 +444,24 @@ def _run_stack_build_command_from_build_config(
|
|||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro via: "
|
||||
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if build_config.image_type == LlamaStackImageType.VENV:
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro (after activating "
|
||||
+ colored(image_name, "cyan")
|
||||
+ ") via: "
|
||||
+ colored(f"llama stack run {run_config_file}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
elif build_config.image_type == LlamaStackImageType.CONTAINER:
|
||||
cprint(
|
||||
"You can run the container with: "
|
||||
+ colored(
|
||||
f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
|
||||
),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return distro_path
|
||||
else:
|
||||
return _generate_run_config(build_config, build_dir, image_name)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,18 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -48,18 +55,12 @@ class StackRun(Subcommand):
|
|||
"--image-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the image to run. Defaults to the current environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
||||
metavar="KEY=VALUE",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be only venv.",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
|
@ -68,48 +69,22 @@ class StackRun(Subcommand):
|
|||
help="Start the UI server",
|
||||
)
|
||||
|
||||
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and distribution name from args.config"""
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
distro_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a distribution
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
distro_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, distro_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
|
||||
if args.image_type or args.image_name:
|
||||
self.parser.error(
|
||||
"The --image-type and --image-name flags are no longer supported.\n\n"
|
||||
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
|
||||
"For example:\n"
|
||||
" source /path/to/venv/bin/activate\n"
|
||||
" llama stack run <config>\n"
|
||||
)
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = args.image_type, args.image_name
|
||||
|
||||
if args.config:
|
||||
try:
|
||||
|
|
@ -121,10 +96,6 @@ class StackRun(Subcommand):
|
|||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if image_type == ImageType.VENV.value and not config_file:
|
||||
self.parser.error("Config file is required for venv environment")
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
|
|
@ -139,50 +110,67 @@ class StackRun(Subcommand):
|
|||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
||||
self._uvicorn_run(config_file, args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||
if not config_file:
|
||||
self.parser.error("Config file is required")
|
||||
|
||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
||||
uvicorn_config = {
|
||||
"factory": True,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
if keyfile and certfile:
|
||||
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
|
||||
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
|
||||
if config.server.tls_cafile:
|
||||
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
config = None
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
logger.info(f"Listening on {host}:{port}")
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
server_args = argparse.Namespace()
|
||||
for arg in vars(args):
|
||||
# If this is a function, avoid passing it
|
||||
# "args" contains:
|
||||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
for env_var in args.env:
|
||||
if "=" not in env_var:
|
||||
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
|
||||
return
|
||||
key, value = env_var.split("=", 1) # split on first = only
|
||||
if not key:
|
||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
||||
return
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
run_command(run_args)
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
|
|
|
|||
|
|
@ -1,141 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
filename: str
|
||||
expected_hash: str
|
||||
actual_hash: str | None
|
||||
exists: bool
|
||||
matches: bool
|
||||
|
||||
|
||||
class VerifyDownload(Subcommand):
|
||||
"""Llama cli for verifying downloaded model files"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama verify-download",
|
||||
description="Verify integrity of downloaded model files",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
setup_verify_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
required=True,
|
||||
help="Model ID to verify (only for models downloaded from Meta)",
|
||||
)
|
||||
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
||||
|
||||
|
||||
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(filepath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_size), b""):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> dict[str, str]:
|
||||
checksums = {}
|
||||
with open(checklist_path) as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
sha256sum, filepath = line.strip().split(" ", 1)
|
||||
# Remove leading './' if present
|
||||
filepath = filepath.lstrip("./")
|
||||
checksums[filepath] = sha256sum
|
||||
return checksums
|
||||
|
||||
|
||||
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
|
||||
results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
console=console,
|
||||
) as progress:
|
||||
for filepath, expected_hash in checksums.items():
|
||||
full_path = model_dir / filepath
|
||||
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
|
||||
|
||||
exists = full_path.exists()
|
||||
actual_hash = None
|
||||
matches = False
|
||||
|
||||
if exists:
|
||||
actual_hash = calculate_sha256(full_path)
|
||||
matches = actual_hash == expected_hash
|
||||
|
||||
results.append(
|
||||
VerificationResult(
|
||||
filename=filepath,
|
||||
expected_hash=expected_hash,
|
||||
actual_hash=actual_hash,
|
||||
exists=exists,
|
||||
matches=matches,
|
||||
)
|
||||
)
|
||||
|
||||
progress.remove_task(task_id)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
console = Console()
|
||||
model_dir = Path(model_local_dir(args.model_id))
|
||||
checklist_path = model_dir / "checklist.chk"
|
||||
|
||||
if not model_dir.exists():
|
||||
parser.error(f"Model directory not found: {model_dir}")
|
||||
|
||||
if not checklist_path.exists():
|
||||
parser.error(f"Checklist file not found: {checklist_path}")
|
||||
|
||||
checksums = load_checksums(checklist_path)
|
||||
results = verify_files(model_dir, checksums, console)
|
||||
|
||||
# Print results
|
||||
console.print("\nVerification Results:")
|
||||
|
||||
all_good = True
|
||||
for result in results:
|
||||
if not result.exists:
|
||||
console.print(f"[red]❌ {result.filename}: File not found[/red]")
|
||||
all_good = False
|
||||
elif not result.matches:
|
||||
console.print(
|
||||
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
|
||||
f" Expected: {result.expected_hash}\n"
|
||||
f" Got: {result.actual_hash}"
|
||||
)
|
||||
all_good = False
|
||||
else:
|
||||
console.print(f"[green]✓ {result.filename}: Verified[/green]")
|
||||
|
||||
if all_good:
|
||||
console.print("\n[green]All files verified successfully![/green]")
|
||||
|
|
@ -147,7 +147,7 @@ WORKDIR /app
|
|||
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.12 python3.12-pip python3.12-wheel \
|
||||
python3.12-setuptools python3.12-devel gcc make && \
|
||||
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
|
||||
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
|
@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
|
|||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
gcc \
|
||||
gcc g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
|
@ -324,14 +324,14 @@ fi
|
|||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# If a run config is provided, we use the --config flag
|
||||
# If a run config is provided, we use the llama stack CLI
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import httpx
|
|||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
|
@ -114,7 +113,24 @@ def create_api_client_class(protocol) -> type:
|
|||
break
|
||||
kwargs[param.name] = args[i]
|
||||
|
||||
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
|
||||
if not webmethods:
|
||||
raise RuntimeError(f"Method {method} has no webmethod decorators")
|
||||
|
||||
# Choose the preferred webmethod (non-deprecated if available)
|
||||
preferred_webmethod = None
|
||||
for wm in webmethods:
|
||||
if not getattr(wm, "deprecated", False):
|
||||
preferred_webmethod = wm
|
||||
break
|
||||
|
||||
# If no non-deprecated found, use the first one
|
||||
if preferred_webmethod is None:
|
||||
preferred_webmethod = webmethods[0]
|
||||
|
||||
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, list):
|
||||
|
|
|
|||
|
|
@ -3,5 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batch_inference import *
|
||||
309
llama_stack/core/conversations/conversations.py
Normal file
309
llama_stack/core/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in conversation service.
|
||||
|
||||
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
|
||||
:param policy: Access control rules
|
||||
"""
|
||||
|
||||
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
|
||||
)
|
||||
policy: list[AccessRule] = []
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the conversation service implementation."""
|
||||
impl = ConversationServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ConversationServiceImpl(Conversations):
|
||||
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||
|
||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.policy = config.policy
|
||||
|
||||
base_sql_store = sqlstore_impl(config.conversations_store)
|
||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the store and create tables."""
|
||||
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
|
||||
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_conversations",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"items": ColumnType.JSON,
|
||||
"metadata": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"conversation_items",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"conversation_id": ColumnType.STRING,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"item_data": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation."""
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
conversation_id = f"conv_{random_bytes.hex()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
record_data = {
|
||||
"id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"items": [],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="openai_conversations",
|
||||
data=record_data,
|
||||
)
|
||||
|
||||
if items:
|
||||
item_records = []
|
||||
for item in items:
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
item_records.append(item_record)
|
||||
|
||||
await self.sql_store.insert(table="conversation_items", data=item_records)
|
||||
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
created_at=created_at,
|
||||
metadata=metadata,
|
||||
object="conversation",
|
||||
)
|
||||
|
||||
logger.info(f"Created conversation {conversation_id}")
|
||||
return conversation
|
||||
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID."""
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
return Conversation(
|
||||
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
|
||||
)
|
||||
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID"""
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID."""
|
||||
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
logger.info(f"Deleted conversation {conversation_id}")
|
||||
return ConversationDeletedResource(id=conversation_id)
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||
"""Validate conversation ID format."""
|
||||
if not conversation_id.startswith("conv_"):
|
||||
raise ValueError(
|
||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
)
|
||||
|
||||
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||
"""Get existing item ID or generate one if missing."""
|
||||
if item.id is None:
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
if item.type == "message":
|
||||
item_id = f"msg_{random_bytes.hex()}"
|
||||
else:
|
||||
item_id = f"item_{random_bytes.hex()}"
|
||||
item_dict["id"] = item_id
|
||||
return item_id
|
||||
return item.id
|
||||
|
||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Validate conversation ID and return the conversation if it exists."""
|
||||
self._validate_conversation_id(conversation_id)
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create (add) items to a conversation."""
|
||||
await self._get_validated_conversation(conversation_id)
|
||||
|
||||
created_items = []
|
||||
base_time = int(time.time())
|
||||
|
||||
for i, item in enumerate(items):
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
# make each timestamp unique to maintain order
|
||||
created_at = base_time + i
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||
try:
|
||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_items",
|
||||
data={"created_at": created_at, "item_data": item_dict},
|
||||
where={"id": item_id},
|
||||
)
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||
|
||||
# Convert created items (dicts) to proper ConversationItem types
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=created_items[0]["id"] if created_items else None,
|
||||
last_id=created_items[-1]["id"] if created_items else None,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
# Get item from conversation_items table
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
return adapter.validate_python(record["item_data"])
|
||||
|
||||
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
||||
"""List items in the conversation."""
|
||||
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
|
||||
records = result.data
|
||||
|
||||
if order != NOT_GIVEN and order == "asc":
|
||||
records.sort(key=lambda x: x["created_at"])
|
||||
else:
|
||||
records.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
actual_limit = 20
|
||||
if limit != NOT_GIVEN and isinstance(limit, int):
|
||||
actual_limit = limit
|
||||
|
||||
records = records[:actual_limit]
|
||||
items = [record["item_data"] for record in records]
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
|
||||
|
||||
first_id = response_items[0].id if response_items else None
|
||||
last_id = response_items[-1].id if response_items else None
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
_ = await self._get_validated_conversation(conversation_id)
|
||||
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
await self.sql_store.delete(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||
return ConversationItemDeletedResource(id=item_id)
|
||||
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
|
|
@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
|
|
@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
|
|||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -433,6 +428,12 @@ class InferenceStoreConfig(BaseModel):
|
|||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class ResponsesStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
|
|
@ -474,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
|
|||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
conversations_store: SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the conversations API.
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
|
|
@ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.shields,
|
||||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_dbs,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.datasets,
|
||||
router_api=Api.datasetio,
|
||||
|
|
@ -243,6 +239,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
|
@ -251,9 +248,20 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
|
|
|
|||
42
llama_stack/core/id_generation.py
Normal file
42
llama_stack/core/id_generation.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
IdFactory = Callable[[], str]
|
||||
IdOverride = Callable[[str, IdFactory], str]
|
||||
|
||||
_id_override: IdOverride | None = None
|
||||
|
||||
|
||||
def generate_object_id(kind: str, factory: IdFactory) -> str:
|
||||
"""Generate an identifier for the given kind using the provided factory.
|
||||
|
||||
Allows tests to override ID generation deterministically by installing an
|
||||
override callback via :func:`set_id_override`.
|
||||
"""
|
||||
|
||||
override = _id_override
|
||||
if override is not None:
|
||||
return override(kind, factory)
|
||||
return factory()
|
||||
|
||||
|
||||
def set_id_override(override: IdOverride) -> IdOverride | None:
|
||||
"""Install an override used to generate deterministic identifiers."""
|
||||
|
||||
global _id_override
|
||||
|
||||
previous = _id_override
|
||||
_id_override = override
|
||||
return previous
|
||||
|
||||
|
||||
def reset_id_override(previous: IdOverride | None) -> None:
|
||||
"""Restore the previous override returned by :func:`set_id_override`."""
|
||||
|
||||
global _id_override
|
||||
_id_override = previous
|
||||
|
|
@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -374,12 +375,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
|
||||
if hasattr(options, "extra_json") and options.extra_json:
|
||||
body |= options.extra_json
|
||||
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
|
@ -442,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
|
@ -489,17 +495,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
|
||||
) -> dict:
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
base_type = get_args(param_type)[0]
|
||||
return {param.name: base_type(**body)}
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
|
|
@ -27,8 +28,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
|
|
@ -53,7 +54,6 @@ from llama_stack.providers.datatypes import (
|
|||
ScoringFunctionsProtocolPrivate,
|
||||
ShieldsProtocolPrivate,
|
||||
ToolGroupsProtocolPrivate,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
|
@ -79,7 +79,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
|
@ -95,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
Api.prompts: Prompts,
|
||||
Api.conversations: Conversations,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
|
|
@ -122,7 +122,6 @@ def additional_protocols_map() -> dict[Api, Any]:
|
|||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
|
|
@ -147,6 +146,7 @@ async def resolve_impls(
|
|||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any] | None = None,
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
|
|
@ -169,7 +169,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
|
@ -277,9 +277,10 @@ async def instantiate_providers(
|
|||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any] | None = None,
|
||||
) -> dict[Api, Any]:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
# Skip providers that are not enabled
|
||||
|
|
@ -412,8 +413,14 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
|
||||
mro = type(obj).__mro__
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
if value.__webmethod__.experimental:
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
|
||||
has_alpha_api = False
|
||||
for webmethod in value.__webmethods__:
|
||||
if webmethod.level == LLAMA_STACK_API_V1ALPHA:
|
||||
has_alpha_api = True
|
||||
break
|
||||
# if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
|
||||
if has_alpha_api:
|
||||
continue
|
||||
if not hasattr(obj, name):
|
||||
missing_methods.append((name, "missing"))
|
||||
|
|
|
|||
|
|
@ -26,10 +26,8 @@ async def get_routing_table_impl(
|
|||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
api_to_tables = {
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
|
|
|
|||
|
|
@ -10,50 +10,40 @@ from collections.abc import AsyncGenerator, AsyncIterator
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
|
@ -191,243 +181,25 @@ class InferenceRouter(Inference):
|
|||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||
else:
|
||||
params = {}
|
||||
if tool_choice:
|
||||
params["tool_choice"] = tool_choice
|
||||
if tool_prompt_format:
|
||||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
elif tool_config.tool_choice == ToolChoice.auto:
|
||||
pass
|
||||
elif tool_config.tool_choice == ToolChoice.required:
|
||||
pass
|
||||
else:
|
||||
# verify tool_choice is one of the tools
|
||||
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||
if tool_config.tool_choice not in tool_names:
|
||||
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
response_stream = await provider.chat_completion(**params)
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response_stream,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
|
||||
response = await provider.chat_completion(**params)
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
tools=tools,
|
||||
tool_config=tool_config,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
response = await provider.completion(**params)
|
||||
if stream:
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
)
|
||||
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response, prompt_tokens=prompt_tokens, model=model
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
await self._get_model(model_id, ModelType.embedding)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
response = await provider.openai_completion(params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
|
|
@ -446,93 +218,49 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
if tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||
if tools is None:
|
||||
if params.tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
|
||||
if params.tools is None:
|
||||
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if params.tools:
|
||||
for tool in params.tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
# Some providers make tool calls even when tool_choice is "none"
|
||||
# so just clear them both out to avoid unexpected tool calls
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tool_choice = None
|
||||
tools = None
|
||||
if params.tool_choice == "none" and params.tools is not None:
|
||||
params.tool_choice = None
|
||||
params.tools = None
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if params.stream:
|
||||
response_stream = await provider.openai_chat_completion(params)
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
messages=params.messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
|
@ -588,8 +316,10 @@ class InferenceRouter(Inference):
|
|||
return await self.store.get_chat_completion(completion_id)
|
||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
async def _nonstream_openai_chat_completion(
|
||||
self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
|
||||
) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(params)
|
||||
for choice in response.choices:
|
||||
# some providers return an empty list for no tool calls in non-streaming responses
|
||||
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
|
||||
|
|
@ -803,7 +533,7 @@ class InferenceRouter(Inference):
|
|||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolsResponse,
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
|
|
@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolsResponse:
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@ import asyncio
|
|||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
|
@ -19,9 +17,11 @@ from llama_stack.apis.vector_io import (
|
|||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileBatchObject,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFilesListInBatchResponse,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreListResponse,
|
||||
VectorStoreObject,
|
||||
|
|
@ -193,7 +193,10 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = all_stores[after_index + 1 :]
|
||||
|
||||
if before:
|
||||
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
|
||||
before_index = next(
|
||||
(i for i, store in enumerate(all_stores) if store.id == before),
|
||||
len(all_stores),
|
||||
)
|
||||
all_stores = all_stores[:before_index]
|
||||
|
||||
# Apply limit
|
||||
|
|
@ -363,3 +366,61 @@ class VectorIORouter(VectorIO):
|
|||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
|
||||
return await self.routing_table.openai_create_vector_store_file_batch(
|
||||
vector_store_id=vector_store_id,
|
||||
file_ids=file_ids,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
) -> VectorStoreFilesListInBatchResponse:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
order=order,
|
||||
)
|
||||
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from typing import Any
|
|||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import Action
|
||||
from llama_stack.core.datatypes import (
|
||||
|
|
@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
|
|||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
|
|
@ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
|
||||
elif api == Api.eval:
|
||||
p.benchmark_store = self
|
||||
elif api == Api.tool_runtime:
|
||||
|
|
@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
|
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def has_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the routing table.
|
||||
|
||||
:param model_id: The model identifier to check
|
||||
:return: True if the model exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
await lookup_model(self, model_id)
|
||||
return True
|
||||
except ModelNotFoundError:
|
||||
return False
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
|
@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
|
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -54,33 +54,33 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
await self._index_tools(toolgroup)
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Other errors that the client cannot fix are logged and
|
||||
# those specific toolgroups are skipped.
|
||||
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
|
||||
logger.debug(e, exc_info=True)
|
||||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
t.toolgroup_id = toolgroup.identifier
|
||||
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
|
||||
for tool in tooldefs:
|
||||
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
|
@ -91,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
async def get_tool(self, tool_name: str) -> ToolDef:
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
|
|
@ -121,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
# baked in some of the code and tests right now.
|
||||
if not toolgroup.mcp_endpoint:
|
||||
await self._index_tools(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
await self.unregister_object(await self.get_tool_group(toolgroup_id))
|
||||
|
|
|
|||
|
|
@ -1,247 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
if vector_db is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return vector_db
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
existing_vector_db = await self.get_vector_db(vector_db_id)
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
@ -27,6 +27,11 @@ class AuthenticationMiddleware:
|
|||
3. Extracts user attributes from the provider's response
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Unauthenticated Access:
|
||||
Endpoints can opt out of authentication by setting require_authentication=False
|
||||
in their @webmethod decorator. This is typically used for operational endpoints
|
||||
like /health and /version to support monitoring, load balancers, and observability tools.
|
||||
|
||||
The middleware supports multiple authentication providers through the AuthProvider interface:
|
||||
- Kubernetes: Validates tokens against the Kubernetes API server
|
||||
- Custom: Validates tokens against a custom endpoint
|
||||
|
|
@ -88,7 +93,26 @@ class AuthenticationMiddleware:
|
|||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
# First, handle authentication
|
||||
# Find the route and check if authentication is required
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
webmethod = None
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass here to run auth anyways
|
||||
pass
|
||||
|
||||
# If webmethod explicitly sets require_authentication=False, allow without auth
|
||||
if webmethod and webmethod.require_authentication is False:
|
||||
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Handle authentication
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
|
|
@ -127,19 +151,7 @@ class AuthenticationMiddleware:
|
|||
)
|
||||
|
||||
# Scope-based API access control
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if webmethod.required_scope:
|
||||
if webmethod and webmethod.required_scope:
|
||||
user = user_from_scope(scope)
|
||||
if not _has_required_scope(webmethod.required_scope, user):
|
||||
return await self._send_auth_error(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from starlette.routing import Route
|
|||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
|
|
@ -54,22 +53,23 @@ def get_all_api_routes(
|
|||
protocol_methods.append((f"{tool_group.value}.{name}", method))
|
||||
|
||||
for name, method in protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
if not webmethods:
|
||||
continue
|
||||
|
||||
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
|
||||
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
|
||||
webmethod = method.__webmethod__ # type: ignore[attr-defined]
|
||||
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
# Create routes for each webmethod decorator
|
||||
for webmethod in webmethods:
|
||||
path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = routes
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
|
@ -12,7 +11,6 @@ import inspect
|
|||
import json
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
|
@ -25,7 +23,6 @@ from typing import Annotated, Any, get_origin
|
|||
import httpx
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
|
@ -36,7 +33,6 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
|
@ -45,22 +41,17 @@ from llama_stack.core.datatypes import (
|
|||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
|
@ -73,13 +64,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
from .tracing import TracingMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
|
@ -148,6 +138,13 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
|
||||
elif isinstance(exc, AuthenticationRequiredError):
|
||||
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
|
||||
elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
|
||||
# Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
|
||||
# These include AuthenticationError (401), PermissionDeniedError (403), etc.
|
||||
# This preserves the actual HTTP status code from the provider
|
||||
status_code = exc.status_code
|
||||
detail = str(exc)
|
||||
return HTTPException(status_code=status_code, detail=detail)
|
||||
else:
|
||||
return HTTPException(
|
||||
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -187,7 +184,17 @@ async def lifespan(app: StackApp):
|
|||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
# If there's a stream parameter at top level, use it
|
||||
if "stream" in kwargs:
|
||||
return kwargs["stream"]
|
||||
|
||||
# If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
|
||||
if "params" in kwargs:
|
||||
params = kwargs["params"]
|
||||
if hasattr(params, "stream"):
|
||||
return params.stream
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
|
|
@ -242,15 +249,31 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
test_context_token = None
|
||||
test_context_var = None
|
||||
reset_test_context_fn = None
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user):
|
||||
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
|
||||
from llama_stack.core.testing_context import (
|
||||
TEST_CONTEXT,
|
||||
reset_test_context,
|
||||
sync_test_context_from_provider_data,
|
||||
)
|
||||
|
||||
test_context_token = sync_test_context_from_provider_data()
|
||||
test_context_var = TEST_CONTEXT
|
||||
reset_test_context_fn = reset_test_context
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
)
|
||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
if test_context_var is not None:
|
||||
context_vars.append(test_context_var)
|
||||
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
|
|
@ -263,11 +286,14 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
if test_context_token is not None and reset_test_context_fn is not None:
|
||||
reset_test_context_fn(test_context_token)
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
|
@ -299,65 +325,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
return route_handler
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
|
@ -398,23 +365,18 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
def create_app() -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
This factory function reads configuration from environment variables:
|
||||
- LLAMA_STACK_CONFIG: Path to config file (required)
|
||||
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
config_file = os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
|
||||
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
|
|
@ -426,16 +388,6 @@ def create_app(
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
|
|
@ -516,6 +468,7 @@ def create_app(
|
|||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
apis_to_serve.add("conversations")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
@ -558,101 +511,6 @@ def create_app(
|
|||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
port = args.port or config.server.port
|
||||
|
||||
ssl_config = None
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
|
||||
if keyfile and certfile:
|
||||
ssl_config = {
|
||||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
if config.server.tls_cafile:
|
||||
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
|
|
@ -679,7 +537,3 @@ def remove_disabled_providers(obj):
|
|||
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
80
llama_stack/core/server/tracing.py
Normal file
80
llama_stack/core/server/tracing.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
|
@ -14,8 +14,8 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
|
|
@ -33,8 +33,8 @@ from llama_stack.apis.shields import Shields
|
|||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
|
|
@ -52,9 +52,7 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
|
|
@ -75,6 +73,7 @@ class LlamaStack(
|
|||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
Conversations,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
@ -82,7 +81,6 @@ class LlamaStack(
|
|||
RESOURCES = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
|
|
@ -273,22 +271,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|||
return config_dict
|
||||
|
||||
|
||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||
"""Validate and split an environment variable key-value pair."""
|
||||
try:
|
||||
key, value = env_pair.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
||||
return key, value
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
||||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
|
|
@ -314,6 +296,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
)
|
||||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
conversations_impl = ConversationServiceImpl(
|
||||
ConversationServiceConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.conversations] = conversations_impl
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
|
|
@ -325,25 +313,32 @@ class Stack:
|
|||
# asked for in the run config.
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
from llama_stack.testing.api_recorder import setup_api_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
TEST_RECORDING_CONTEXT = setup_api_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
internal_impls = {}
|
||||
add_internal_implementations(internal_impls, self.run_config)
|
||||
|
||||
impls = await resolve_impls(
|
||||
self.run_config,
|
||||
self.provider_registry or get_provider_registry(self.run_config),
|
||||
dist_registry,
|
||||
policy,
|
||||
internal_impls,
|
||||
)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
if Api.conversations in impls:
|
||||
await impls[Api.conversations].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
|
|
@ -388,7 +383,7 @@ class Stack:
|
|||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
logger.error(f"Error during API recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ error_handler() {
|
|||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|||
|
||||
# Initialize variables
|
||||
yaml_config=""
|
||||
env_vars=""
|
||||
other_args=""
|
||||
|
||||
# Process remaining arguments
|
||||
|
|
@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
|
|||
exit 1
|
||||
fi
|
||||
;;
|
||||
--env)
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
other_args="$other_args $1"
|
||||
shift
|
||||
|
|
@ -116,10 +106,9 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.core.server.server \
|
||||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v9"
|
||||
KEY_VERSION = "v10"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
@ -96,10 +96,10 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# warn if the object's providerid is different but proceed with registration
|
||||
if existing_obj and existing_obj.provider_id != obj.provider_id:
|
||||
logger.warning(
|
||||
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
|
||||
if existing_obj and existing_obj != obj:
|
||||
raise ValueError(
|
||||
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
|
||||
"Unregister it first if you want to replace it."
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
|
|
|
|||
44
llama_stack/core/testing_context.py
Normal file
44
llama_stack/core/testing_context.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
|
||||
|
||||
|
||||
def get_test_context() -> str | None:
|
||||
return TEST_CONTEXT.get()
|
||||
|
||||
|
||||
def set_test_context(value: str | None):
|
||||
return TEST_CONTEXT.set(value)
|
||||
|
||||
|
||||
def reset_test_context(token) -> None:
|
||||
TEST_CONTEXT.reset(token)
|
||||
|
||||
|
||||
def sync_test_context_from_provider_data():
|
||||
"""Sync test context from provider data when running in server test mode."""
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
|
||||
return None
|
||||
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
except LookupError:
|
||||
provider_data = None
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
return TEST_CONTEXT.set(provider_data["__test_id"])
|
||||
|
||||
return None
|
||||
|
|
@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
|
|||
from llama_stack.core.ui.page.distribution.models import models
|
||||
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.core.ui.page.distribution.shields import shields
|
||||
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
"Models",
|
||||
"Vector Databases",
|
||||
"Shields",
|
||||
"Scoring Functions",
|
||||
"Datasets",
|
||||
"Benchmarks",
|
||||
]
|
||||
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
|
||||
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
|
||||
selected_resource = option_menu(
|
||||
None,
|
||||
options,
|
||||
|
|
@ -37,8 +35,6 @@ def resources_page():
|
|||
)
|
||||
if selected_resource == "Benchmarks":
|
||||
benchmarks()
|
||||
elif selected_resource == "Vector Databases":
|
||||
vector_dbs()
|
||||
elif selected_resource == "Datasets":
|
||||
datasets()
|
||||
elif selected_resource == "Models":
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def vector_dbs():
|
||||
st.header("Vector Databases")
|
||||
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
|
||||
|
||||
if len(vector_dbs_info) > 0:
|
||||
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
|
||||
st.json(vector_dbs_info[selected_vector_db])
|
||||
else:
|
||||
st.info("No vector databases found")
|
||||
|
|
@ -1,301 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallDelta
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
st.title("🦙 RAG")
|
||||
|
||||
def reset_agent_and_chat():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
def should_disable_input():
|
||||
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||
|
||||
def log_message(message):
|
||||
with st.chat_message(message["role"]):
|
||||
if "tool_output" in message and message["tool_output"]:
|
||||
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
|
||||
st.write(message["tool_output"])
|
||||
st.markdown(message["content"])
|
||||
|
||||
with st.sidebar:
|
||||
# File/Directory Upload Section
|
||||
st.subheader("Upload Documents", divider=True)
|
||||
uploaded_files = st.file_uploader(
|
||||
"Upload file(s) or directory",
|
||||
accept_multiple_files=True,
|
||||
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
|
||||
)
|
||||
# Process uploaded files
|
||||
if uploaded_files:
|
||||
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
||||
# Add memory bank name input field
|
||||
vector_db_name = st.text_input(
|
||||
"Document Collection Name",
|
||||
value="rag_vector_db",
|
||||
help="Enter a unique identifier for this document collection",
|
||||
)
|
||||
if st.button("Create Document Collection"):
|
||||
documents = [
|
||||
RAGDocument(
|
||||
document_id=uploaded_file.name,
|
||||
content=data_url_from_file(uploaded_file),
|
||||
)
|
||||
for i, uploaded_file in enumerate(uploaded_files)
|
||||
]
|
||||
|
||||
providers = llama_stack_api.client.providers.list()
|
||||
vector_io_provider = None
|
||||
|
||||
for x in providers:
|
||||
if x.api == "vector_io":
|
||||
vector_io_provider = x.provider_id
|
||||
|
||||
llama_stack_api.client.vector_dbs.register(
|
||||
vector_db_id=vector_db_name, # Use the user-provided name
|
||||
embedding_dimension=384,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
provider_id=vector_io_provider,
|
||||
)
|
||||
|
||||
# insert documents using the custom vector db name
|
||||
llama_stack_api.client.tool_runtime.rag_tool.insert(
|
||||
vector_db_id=vector_db_name, # Use the user-provided name
|
||||
documents=documents,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
st.success("Vector database created successfully!")
|
||||
|
||||
st.subheader("RAG Parameters", divider=True)
|
||||
|
||||
rag_mode = st.radio(
|
||||
"RAG mode",
|
||||
["Direct", "Agent-based"],
|
||||
captions=[
|
||||
"RAG is performed by directly retrieving the information and augmenting the user query",
|
||||
"RAG is performed by an agent activating a dedicated knowledge search tool.",
|
||||
],
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# select memory banks
|
||||
vector_dbs = llama_stack_api.client.vector_dbs.list()
|
||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||
selected_vector_dbs = st.multiselect(
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_dbs,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
st.subheader("Inference Parameters", divider=True)
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||
selected_model = st.selectbox(
|
||||
label="Choose a model",
|
||||
options=available_models,
|
||||
index=0,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful assistant. ",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
reset_agent_and_chat()
|
||||
st.rerun()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "displayed_messages" not in st.session_state:
|
||||
st.session_state.displayed_messages = []
|
||||
|
||||
# Display chat history
|
||||
for message in st.session_state.displayed_messages:
|
||||
log_message(message)
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if rag_mode == "Agent-based":
|
||||
agent = create_agent()
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
def agent_process_prompt(prompt):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Send the prompt to the agent
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
retrieval_message_placeholder.write(retrieval_response)
|
||||
else:
|
||||
full_response += log.content
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
st.session_state.displayed_messages.append(
|
||||
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
|
||||
)
|
||||
|
||||
def direct_process_prompt(prompt):
|
||||
# Add the system prompt in the beginning of the conversation
|
||||
if len(st.session_state.messages) == 0:
|
||||
st.session_state.messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# Query the vector DB
|
||||
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
|
||||
content=prompt, vector_db_ids=list(selected_vector_dbs)
|
||||
)
|
||||
prompt_context = rag_response.content
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
with st.expander(label="Retrieval Output", expanded=False):
|
||||
st.write(prompt_context)
|
||||
|
||||
retrieval_message_placeholder = st.empty()
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
|
||||
# Construct the extended prompt
|
||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||
|
||||
# Run inference directly
|
||||
st.session_state.messages.append({"role": "user", "content": extended_prompt})
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model_id=selected_model,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
for chunk in response:
|
||||
response_delta = chunk.event.delta
|
||||
if isinstance(response_delta, ToolCallDelta):
|
||||
retrieval_response += response_delta.tool_call.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
else:
|
||||
full_response += chunk.event.delta.text
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
|
||||
st.session_state.messages.append(response_dict)
|
||||
st.session_state.displayed_messages.append(response_dict)
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
# Add user message to chat history
|
||||
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# store the prompt to process it after page refresh
|
||||
st.session_state.prompt = prompt
|
||||
|
||||
# force page refresh to disable the settings widgets
|
||||
st.rerun()
|
||||
|
||||
if "prompt" in st.session_state and st.session_state.prompt is not None:
|
||||
if rag_mode == "Agent-based":
|
||||
agent_process_prompt(st.session_state.prompt)
|
||||
else: # rag_mode == "Direct"
|
||||
direct_process_prompt(st.session_state.prompt)
|
||||
st.session_state.prompt = None
|
||||
|
||||
|
||||
rag_chat_page()
|
||||
|
|
@ -81,7 +81,7 @@ def tool_chat_page():
|
|||
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
@ -224,6 +224,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/conversations.db
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
|
|
|
|||
|
|
@ -115,13 +115,13 @@ docker run -it \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
||||
-v /home/hjshah/git/llama-stack:/app/llama-stack-source \
|
||||
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
||||
# localhost/distribution-dell:dev if building / testing locally
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }}\
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -142,14 +142,14 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e SAFETY_MODEL=$SAFETY_MODEL \
|
||||
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
|
@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type conda
|
||||
llama stack run {{ name }}
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run {{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
SAFETY_MODEL=$SAFETY_MODEL \
|
||||
DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -101,6 +101,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -97,6 +97,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -29,31 +29,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Prerequisite: Downloading Models
|
||||
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
|
||||
```
|
||||
$ llama model list --downloaded
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ Model ┃ Size ┃ Modified Time ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||
Please check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models using the Hugging Face CLI.
|
||||
```
|
||||
|
||||
## Running the Distribution
|
||||
|
|
@ -72,9 +48,9 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
|
@ -86,10 +62,10 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -98,16 +74,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type venv
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llama stack run distributions/{{ name }}/run.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port 8321
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llama stack run distributions/{{ name }}/run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port 8321
|
||||
```
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -114,6 +114,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -104,6 +104,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -49,22 +49,22 @@ The deployed platform includes the NIM Proxy microservice, which is the service
|
|||
### Datasetio API: NeMo Data Store
|
||||
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Datasetio docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Eval API: NeMo Evaluator
|
||||
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Eval docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Post-Training API: NeMo Customizer
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Post-Training docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Safety API: NeMo Guardrails
|
||||
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Safety docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/safety/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
## Deploying models
|
||||
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||
|
|
@ -118,10 +118,10 @@ docker run \
|
|||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -131,11 +131,11 @@ If you've set up your local development environment, you can also build the imag
|
|||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
--port 8321
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in [docs/notebooks/nvidia](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks/nvidia).
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@
|
|||
from pathlib import Path
|
||||
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
|
||||
|
||||
|
|
@ -68,9 +67,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
provider_id="nvidia",
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"nvidia": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
|
|
@ -78,7 +74,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
|
|
@ -86,7 +81,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
|
|
@ -95,7 +89,6 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
"eval": [eval_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
"run-with-safety.yaml": RunConfigSettings(
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -103,6 +103,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -92,90 +92,10 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-405b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-1b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-3b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: nvidia/vila
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/vila
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-e5-v5
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 4096
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: snowflake/arctic-embed-l
|
||||
provider_id: nvidia
|
||||
provider_model_id: snowflake/arctic-embed-l
|
||||
model_type: embedding
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/conversations.db
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -134,6 +134,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: gpt-4o
|
||||
|
|
|
|||
|
|
@ -86,6 +86,9 @@ inference_store:
|
|||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/postgres-demo}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
@ -227,6 +227,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/conversations.db
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
@ -224,6 +224,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/conversations.db
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ class RunConfigSettings(BaseModel):
|
|||
default_benchmarks: list[BenchmarkInput] | None = None
|
||||
metadata_store: dict | None = None
|
||||
inference_store: dict | None = None
|
||||
conversations_store: dict | None = None
|
||||
|
||||
def run_config(
|
||||
self,
|
||||
|
|
@ -240,6 +241,11 @@ class RunConfigSettings(BaseModel):
|
|||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="inference_store.db",
|
||||
),
|
||||
"conversations_store": self.conversations_store
|
||||
or SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="conversations.db",
|
||||
),
|
||||
"models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
|
||||
"shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
|
||||
"vector_dbs": [],
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .watsonx import get_distribution_template # noqa: F401
|
||||
|
|
|
|||
|
|
@ -3,44 +3,33 @@ distribution_spec:
|
|||
description: Use watsonx for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::watsonx
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_type: inline::faiss
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- aiosqlite
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
- files
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
|
|
@ -19,8 +19,6 @@ providers:
|
|||
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||
api_key: ${env.WATSONX_API_KEY:=}
|
||||
project_id: ${env.WATSONX_PROJECT_ID:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
@ -48,7 +46,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -109,102 +107,10 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-3-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-2-13b-chat
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-2-13b
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-8b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-1b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-3b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-guard-3-11b-vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/conversations.db
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||
|
|
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
config=WatsonXConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"watsonx": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
|
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="sentence-transformers",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"inference": [inference_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -30,8 +30,21 @@ CATEGORIES = [
|
|||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
"openai",
|
||||
"openai_responses",
|
||||
"openai_conversations",
|
||||
"testing",
|
||||
"providers",
|
||||
"models",
|
||||
"files",
|
||||
"vector_io",
|
||||
"tool_runtime",
|
||||
"cli",
|
||||
"post_training",
|
||||
"scoring",
|
||||
"tests",
|
||||
]
|
||||
UNCATEGORIZED = "uncategorized"
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: dict[str, int] = dict.fromkeys(CATEGORIES, DEFAULT_LOG_LEVEL)
|
||||
|
|
@ -121,7 +134,10 @@ def strip_rich_markup(text):
|
|||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=150)
|
||||
# Set a reasonable default width for console output, especially when redirected to files
|
||||
console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
|
||||
# Don't force terminal codes to avoid ANSI escape codes in log files
|
||||
kwargs["console"] = Console(width=console_width)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
|
|
@ -165,7 +181,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
|
|||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
record.category = UNCATEGORIZED # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
|
|
@ -247,7 +263,20 @@ def get_logger(
|
|||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
if category in _category_levels:
|
||||
log_level = _category_levels[category]
|
||||
else:
|
||||
root_category = category.split("::")[0]
|
||||
if root_category in _category_levels:
|
||||
log_level = _category_levels[root_category]
|
||||
else:
|
||||
if category != UNCATEGORIZED:
|
||||
raise ValueError(
|
||||
f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
|
||||
f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
|
||||
)
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
logger.setLevel(log_level)
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,14 +37,7 @@ RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
|
|||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: str | dict[str, RecursiveType]
|
||||
arguments_json: str | None = None
|
||||
arguments: str
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -88,17 +81,11 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: str | None = None
|
||||
required: bool | None = True
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: BuiltinTool | str
|
||||
description: str | None = None
|
||||
parameters: dict[str, ToolParamDefinition] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -232,8 +232,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from typing import Any
|
|||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
|
@ -101,11 +100,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set required_params = t.input_schema.get('required', []) -%}
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -114,11 +110,11 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
{
|
||||
"{{name}}": {
|
||||
"type": "object",
|
||||
"description": "{{param.description}}"
|
||||
"description": "{{param.get('description', '')}}"
|
||||
}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
|
|
@ -143,17 +139,19 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -170,11 +168,14 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set modified_params = t.parameters.copy() -%}
|
||||
{%- for key, value in modified_params.items() -%}
|
||||
{%- if 'default' in value -%}
|
||||
{%- set _ = value.pop('default', None) -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set modified_params = {} -%}
|
||||
{%- for key, value in tprops.items() -%}
|
||||
{%- set param_copy = value.copy() -%}
|
||||
{%- if 'default' in param_copy -%}
|
||||
{%- set _ = param_copy.pop('default', None) -%}
|
||||
{%- endif -%}
|
||||
{%- set _ = modified_params.update({key: param_copy}) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tparams = modified_params | tojson -%}
|
||||
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||
|
|
@ -205,17 +206,19 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -255,11 +258,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = (t.input_schema or {}).get('properties', {}) -%}
|
||||
{%- set required_params = (t.input_schema or {}).get('required', []) -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
|
|
@ -267,11 +267,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
"type": "{{param.get('type', 'string')}}",
|
||||
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
|
||||
"default": "{{param.get('default')}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
|
|
@ -299,18 +299,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for",
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -220,17 +220,18 @@ class ToolUtils:
|
|||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
args = json.loads(t.arguments)
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
return args["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
|
|
@ -239,12 +240,11 @@ class ToolUtils:
|
|||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
"parameters": args,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
return f"<function={fname}>{t.arguments}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
|
|
@ -260,7 +260,7 @@ class ToolUtils:
|
|||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in args.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
|
@ -184,7 +185,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
|
@ -185,7 +186,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -298,8 +298,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue