mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore(package): migrate to src/ layout (#3920)
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
parent
98a5047f9d
commit
471b1b248b
791 changed files with 2983 additions and 456 deletions
10
src/llama_stack/__init__.py
Normal file
10
src/llama_stack/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.library_client import ( # noqa: F401
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
5
src/llama_stack/apis/__init__.py
Normal file
5
src/llama_stack/apis/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
7
src/llama_stack/apis/agents/__init__.py
Normal file
7
src/llama_stack/apis/agents/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agents import *
|
||||
894
src/llama_stack/apis/agents/agents.py
Normal file
894
src/llama_stack/apis/agents/agents.py
Normal file
|
|
@ -0,0 +1,894 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ResponseGuardrailSpec(BaseModel):
|
||||
"""Specification for a guardrail to apply during response generation.
|
||||
|
||||
:param type: The type/identifier of the guardrail.
|
||||
"""
|
||||
|
||||
type: str
|
||||
# TODO: more fields to be added for guardrail configuration
|
||||
|
||||
|
||||
ResponseGuardrail = str | ResponseGuardrailSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
:param content: The content of the attachment.
|
||||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""A document to be used by an agent.
|
||||
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
"""A common step in an agent turn.
|
||||
|
||||
:param turn_id: The ID of the turn.
|
||||
:param step_id: The ID of the step.
|
||||
:param started_at: The time the step started.
|
||||
:param completed_at: The time the step completed.
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class StepType(StrEnum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
:cvar tool_execution: The step is a tool execution step that executes a tool call.
|
||||
:cvar shield_call: The step is a shield call step that checks for safety violations.
|
||||
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
|
||||
"""
|
||||
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
||||
:param model_response: The response from the LLM.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference] = StepType.inference
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
"""A tool execution step in an agent turn.
|
||||
|
||||
:param tool_calls: The tool calls to execute.
|
||||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
"""A shield call step in an agent turn.
|
||||
|
||||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call] = StepType.shield_call
|
||||
violation: SafetyViolation | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
"""A memory retrieval step in an agent turn.
|
||||
|
||||
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
|
||||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||
# TODO: should this be List[str]?
|
||||
vector_db_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System.
|
||||
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param input_messages: List of messages that initiated this turn
|
||||
:param steps: Ordered list of processing steps executed during this turn
|
||||
:param output_message: The model's generated response containing content and metadata
|
||||
:param output_attachments: (Optional) Files or media attached to the agent's response
|
||||
:param started_at: Timestamp when the turn began
|
||||
:param completed_at: (Optional) Timestamp when the turn finished, if completed
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System.
|
||||
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param session_name: Human-readable name for the session
|
||||
:param turns: List of all turns that have occurred in this session
|
||||
:param started_at: Timestamp when the session was created
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: list[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = str | AgentToolGroupWithArgs
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
||||
max_infer_iters: int | None = 10
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_config:
|
||||
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
|
||||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
else:
|
||||
params = {}
|
||||
if self.tool_choice:
|
||||
params["tool_choice"] = self.tool_choice
|
||||
if self.tool_prompt_format:
|
||||
params["tool_prompt_format"] = self.tool_prompt_format
|
||||
self.tool_config = ToolConfig(**params)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
"""Configuration for an agent.
|
||||
|
||||
:param model: The model identifier to use for the agent
|
||||
:param instructions: The system instructions for the agent
|
||||
:param name: Optional name for the agent, used in telemetry and identification
|
||||
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
|
||||
:param response_format: Optional response format configuration
|
||||
"""
|
||||
|
||||
model: str
|
||||
instructions: str
|
||||
name: str | None = None
|
||||
enable_session_persistence: bool | None = False
|
||||
response_format: ResponseFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
"""An agent instance with configuration and metadata.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param agent_config: Configuration settings for the agent
|
||||
:param created_at: Timestamp when the agent was created
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(StrEnum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
turn_awaiting_input = "turn_awaiting_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
"""Payload for step start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param metadata: (Optional) Additional metadata for the step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
"""Payload for step completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param step_details: Complete details of the executed step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
"""Payload for step progress events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param delta: Incremental content changes during step execution
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
delta: ContentDelta
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
"""Payload for turn start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
"""Payload for turn completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Complete turn data including all steps and results
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
"""Payload for turn awaiting input events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Turn data when waiting for external tool responses
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
|
||||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
AgentTurnResponseStepStartPayload
|
||||
| AgentTurnResponseStepProgressPayload
|
||||
| AgentTurnResponseStepCompletePayload
|
||||
| AgentTurnResponseTurnStartPayload
|
||||
| AgentTurnResponseTurnCompletePayload
|
||||
| AgentTurnResponseTurnAwaitingInputPayload,
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""An event in an agent turn response stream.
|
||||
|
||||
:param payload: Event-specific payload containing event data
|
||||
"""
|
||||
|
||||
payload: AgentTurnResponseEventPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent.
|
||||
|
||||
:param agent_id: Unique identifier for the created agent
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentSessionCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent session.
|
||||
|
||||
:param session_id: Unique identifier for the created session
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
"""Request to create a new turn for an agent.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param messages: List of messages to start the turn with
|
||||
:param documents: (Optional) List of documents to provide to the agent
|
||||
:param toolgroups: (Optional) List of tool groups to make available for this turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param tool_config: (Optional) Tool configuration to override agent defaults
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: list[UserMessage | ToolResponseMessage]
|
||||
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
|
||||
stream: bool | None = False
|
||||
tool_config: ToolConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
"""Request to resume an agent turn with tool responses.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param tool_responses: List of tool responses to submit to continue the turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: list[ToolResponse]
|
||||
stream: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
"""Streamed agent turn completion response.
|
||||
|
||||
:param event: Individual event in the agent turn response stream
|
||||
"""
|
||||
|
||||
event: AgentTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentStepResponse(BaseModel):
|
||||
"""Response containing details of a specific agent step.
|
||||
|
||||
:param step: The complete step data and execution details
|
||||
"""
|
||||
|
||||
step: Step
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents
|
||||
|
||||
APIs for creating and interacting with agentic systems."""
|
||||
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
"""Create an agent with the given configuration.
|
||||
|
||||
:param agent_config: The configuration for the agent.
|
||||
:returns: An AgentCreateResponse with the agent ID.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
:param session_id: The ID of the session to create the turn for.
|
||||
:param messages: List of messages to start the turn with.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param documents: (Optional) List of documents to create the turn with.
|
||||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
||||
:param agent_id: The ID of the agent to resume.
|
||||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn:
|
||||
"""Retrieve an agent turn by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the turn for.
|
||||
:param session_id: The ID of the session to get the turn for.
|
||||
:param turn_id: The ID of the turn to get.
|
||||
:returns: A Turn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_step(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
step_id: str,
|
||||
) -> AgentStepResponse:
|
||||
"""Retrieve an agent step by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the step for.
|
||||
:param session_id: The ID of the session to get the step for.
|
||||
:param turn_id: The ID of the turn to get the step for.
|
||||
:param step_id: The ID of the step to get.
|
||||
:returns: An AgentStepResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
"""Create a new session for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the session for.
|
||||
:param session_name: The name of the session to create.
|
||||
:returns: An AgentSessionCreateResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
:returns: A Session.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID and its associated turns.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent by its ID and its associated sessions and turns.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of agents to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/sessions",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of sessions to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
# We situate the OpenAI Responses API in the Agents API just like we did things
|
||||
# for Inference. The Responses API, in its intent, serves the same purpose as
|
||||
# the Agents API above -- it is essentially a lightweight "agentic loop" with
|
||||
# integrated tool calling.
|
||||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
"""Get a model response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to retrieve.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
guardrails: Annotated[
|
||||
list[ResponseGuardrail] | None,
|
||||
ExtraBodyField(
|
||||
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a model response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""List all responses.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The number of responses to return.
|
||||
:param model: The model to filter responses by.
|
||||
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
|
||||
:returns: A ListOpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: The order to return the input items in. Default is desc.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete a response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to delete.
|
||||
:returns: An OpenAIDeleteResponseObject
|
||||
"""
|
||||
...
|
||||
1306
src/llama_stack/apis/agents/openai_responses.py
Normal file
1306
src/llama_stack/apis/agents/openai_responses.py
Normal file
File diff suppressed because it is too large
Load diff
9
src/llama_stack/apis/batches/__init__.py
Normal file
9
src/llama_stack/apis/batches/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
||||
100
src/llama_stack/apis/batches/batches.py
Normal file
100
src/llama_stack/apis/batches/batches.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
except ImportError as e:
|
||||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
:param batch_id: The ID of the batch to retrieve.
|
||||
:returns: The batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
:param batch_id: The ID of the batch to cancel.
|
||||
:returns: The updated batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""List all batches for the current user.
|
||||
|
||||
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||
:param limit: Number of batches to return (default 20, max 100).
|
||||
:returns: A list of batch objects.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/benchmarks/__init__.py
Normal file
7
src/llama_stack/apis/benchmarks/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .benchmarks import *
|
||||
108
src/llama_stack/apis/benchmarks/benchmarks.py
Normal file
108
src/llama_stack/apis/benchmarks/benchmarks.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class CommonBenchmarkFields(BaseModel):
|
||||
dataset_id: str
|
||||
scoring_functions: list[str]
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Metadata for this evaluation task",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Benchmark(CommonBenchmarkFields, Resource):
|
||||
"""A benchmark resource for evaluating model performance.
|
||||
|
||||
:param dataset_id: Identifier of the dataset to use for the benchmark evaluation
|
||||
:param scoring_functions: List of scoring function identifiers to apply during evaluation
|
||||
:param metadata: Metadata for this evaluation task
|
||||
:param type: The resource type, always benchmark
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
|
||||
|
||||
@property
|
||||
def benchmark_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_benchmark_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
|
||||
benchmark_id: str
|
||||
provider_id: str | None = None
|
||||
provider_benchmark_id: str | None = None
|
||||
|
||||
|
||||
class ListBenchmarksResponse(BaseModel):
|
||||
data: list[Benchmark]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
"""List all benchmarks.
|
||||
|
||||
:returns: A ListBenchmarksResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Benchmark:
|
||||
"""Get a benchmark by its ID.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to get.
|
||||
:returns: A Benchmark.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def register_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: list[str],
|
||||
provider_benchmark_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Register a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to register.
|
||||
:param dataset_id: The ID of the dataset to use for the benchmark.
|
||||
:param scoring_functions: The scoring functions to use for the benchmark.
|
||||
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
|
||||
:param provider_id: The ID of the provider to use for the benchmark.
|
||||
:param metadata: The metadata to use for the benchmark.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to unregister.
|
||||
"""
|
||||
...
|
||||
5
src/llama_stack/apis/common/__init__.py
Normal file
5
src/llama_stack/apis/common/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
143
src/llama_stack/apis/common/content_types.py
Normal file
143
src/llama_stack/apis/common/content_types.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolCall
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URL(BaseModel):
|
||||
"""A URL reference to external content.
|
||||
|
||||
:param uri: The URL string pointing to the resource
|
||||
"""
|
||||
|
||||
uri: str
|
||||
|
||||
|
||||
class _URLOrData(BaseModel):
|
||||
"""
|
||||
A URL or a base64 encoded string
|
||||
|
||||
:param url: A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits.
|
||||
:param data: base64 encoded image data as string
|
||||
"""
|
||||
|
||||
url: URL | None = None
|
||||
# data is a base64 encoded string, hint with contentEncoding=base64
|
||||
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validator(cls, values):
|
||||
if isinstance(values, dict):
|
||||
return values
|
||||
return {"url": values}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(BaseModel):
|
||||
"""A image content item
|
||||
|
||||
:param type: Discriminator type of the content item. Always "image"
|
||||
:param image: Image as a base64 encoded string or an URL
|
||||
"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: _URLOrData
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TextContentItem(BaseModel):
|
||||
"""A text content item
|
||||
|
||||
:param type: Discriminator type of the content item. Always "text"
|
||||
:param text: Text content
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = Annotated[
|
||||
ImageContentItem | TextContentItem,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
|
||||
register_schema(InterleavedContent, name="InterleavedContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TextDelta(BaseModel):
|
||||
"""A text content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "text"
|
||||
:param text: The incremental text content
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageDelta(BaseModel):
|
||||
"""An image content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "image"
|
||||
:param image: The incremental image data as bytes
|
||||
"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: bytes
|
||||
|
||||
|
||||
class ToolCallParseStatus(Enum):
|
||||
"""Status of tool call parsing during streaming.
|
||||
:cvar started: Tool call parsing has begun
|
||||
:cvar in_progress: Tool call parsing is ongoing
|
||||
:cvar failed: Tool call parsing failed
|
||||
:cvar succeeded: Tool call parsing completed successfully
|
||||
"""
|
||||
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
succeeded = "succeeded"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolCallDelta(BaseModel):
|
||||
"""A tool call content delta for streaming responses.
|
||||
|
||||
:param type: Discriminator type of the delta. Always "tool_call"
|
||||
:param tool_call: Either an in-progress tool call string or the final parsed tool call
|
||||
:param parse_status: Current parsing status of the tool call
|
||||
"""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
|
||||
# you either send an in-progress tool call so the client can stream a long
|
||||
# code generation or you send the final parsed tool call at the end of the
|
||||
# stream
|
||||
tool_call: str | ToolCall
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = Annotated[
|
||||
TextDelta | ImageDelta | ToolCallDelta,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ContentDelta, name="ContentDelta")
|
||||
103
src/llama_stack/apis/common/errors.py
Normal file
103
src/llama_stack/apis/common/errors.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Custom Llama Stack Exception classes should follow the following schema
|
||||
# 1. All classes should inherit from an existing Built-In Exception class: https://docs.python.org/3/library/exceptions.html
|
||||
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
|
||||
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
|
||||
|
||||
|
||||
class ResourceNotFoundError(ValueError):
|
||||
"""generic exception for a missing Llama Stack resource"""
|
||||
|
||||
def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None:
|
||||
message = (
|
||||
f"{resource_type} '{resource_name}' not found. Use '{client_list}' to list available {resource_type}s."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UnsupportedModelError(ValueError):
|
||||
"""raised when model is not present in the list of supported models"""
|
||||
|
||||
def __init__(self, model_name: str, supported_models_list: list[str]):
|
||||
message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced model"""
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
super().__init__(model_name, "Model", "client.models.list()")
|
||||
|
||||
|
||||
class VectorStoreNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced vector store"""
|
||||
|
||||
def __init__(self, vector_store_name: str) -> None:
|
||||
super().__init__(vector_store_name, "Vector Store", "client.vector_dbs.list()")
|
||||
|
||||
|
||||
class DatasetNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced dataset"""
|
||||
|
||||
def __init__(self, dataset_name: str) -> None:
|
||||
super().__init__(dataset_name, "Dataset", "client.datasets.list()")
|
||||
|
||||
|
||||
class ToolGroupNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced tool group"""
|
||||
|
||||
def __init__(self, toolgroup_name: str) -> None:
|
||||
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced session or access is denied"""
|
||||
|
||||
def __init__(self, session_name: str) -> None:
|
||||
message = f"Session '{session_name}' not found or access denied."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelTypeError(TypeError):
|
||||
"""raised when a model is present but not the correct type"""
|
||||
|
||||
def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
|
||||
message = (
|
||||
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConflictError(ValueError):
|
||||
"""raised when an operation cannot be performed due to a conflict with the current state"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class TokenValidationError(ValueError):
|
||||
"""raised when token validation fails during authentication"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConversationNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced conversation"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
|
||||
|
||||
|
||||
class InvalidConversationIdError(ValueError):
|
||||
"""raised when a conversation ID has an invalid format"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
super().__init__(message)
|
||||
38
src/llama_stack/apis/common/job_types.py
Normal file
38
src/llama_stack/apis/common/job_types.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
"""Status of a job execution.
|
||||
:cvar completed: Job has finished successfully
|
||||
:cvar in_progress: Job is currently running
|
||||
:cvar failed: Job has failed during execution
|
||||
:cvar scheduled: Job is scheduled but not yet started
|
||||
:cvar cancelled: Job was cancelled before completion
|
||||
"""
|
||||
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
scheduled = "scheduled"
|
||||
cancelled = "cancelled"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
"""A job execution instance with status tracking.
|
||||
|
||||
:param job_id: Unique identifier for the job
|
||||
:param status: Current execution status of the job
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
status: JobStatus
|
||||
36
src/llama_stack/apis/common/responses.py
Normal file
36
src/llama_stack/apis/common/responses.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
"""Sort order for paginated responses.
|
||||
:cvar asc: Ascending order
|
||||
:cvar desc: Descending order
|
||||
"""
|
||||
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""A generic paginated response that follows a simple format.
|
||||
|
||||
:param data: The list of items for the current page
|
||||
:param has_more: Whether there are more items available after this set
|
||||
:param url: The URL for accessing this list
|
||||
"""
|
||||
|
||||
data: list[dict[str, Any]]
|
||||
has_more: bool
|
||||
url: str | None = None
|
||||
47
src/llama_stack/apis/common/training_types.py
Normal file
47
src/llama_stack/apis/common/training_types.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingMetric(BaseModel):
|
||||
"""Training metrics captured during post-training jobs.
|
||||
|
||||
:param epoch: Training epoch number
|
||||
:param train_loss: Loss value on the training dataset
|
||||
:param validation_loss: Loss value on the validation dataset
|
||||
:param perplexity: Perplexity metric indicating model confidence
|
||||
"""
|
||||
|
||||
epoch: int
|
||||
train_loss: float
|
||||
validation_loss: float
|
||||
perplexity: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Checkpoint(BaseModel):
|
||||
"""Checkpoint created during training runs.
|
||||
|
||||
:param identifier: Unique identifier for the checkpoint
|
||||
:param created_at: Timestamp when the checkpoint was created
|
||||
:param epoch: Training epoch when the checkpoint was saved
|
||||
:param post_training_job_id: Identifier of the training job that created this checkpoint
|
||||
:param path: File system path where the checkpoint is stored
|
||||
:param training_metrics: (Optional) Training metrics associated with this checkpoint
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
created_at: datetime
|
||||
epoch: int
|
||||
post_training_job_id: str
|
||||
path: str
|
||||
training_metrics: PostTrainingMetric | None = None
|
||||
158
src/llama_stack/apis/common/type_system.py
Normal file
158
src/llama_stack/apis/common/type_system.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StringType(BaseModel):
|
||||
"""Parameter type for string values.
|
||||
|
||||
:param type: Discriminator type. Always "string"
|
||||
"""
|
||||
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NumberType(BaseModel):
|
||||
"""Parameter type for numeric values.
|
||||
|
||||
:param type: Discriminator type. Always "number"
|
||||
"""
|
||||
|
||||
type: Literal["number"] = "number"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BooleanType(BaseModel):
|
||||
"""Parameter type for boolean values.
|
||||
|
||||
:param type: Discriminator type. Always "boolean"
|
||||
"""
|
||||
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ArrayType(BaseModel):
|
||||
"""Parameter type for array values.
|
||||
|
||||
:param type: Discriminator type. Always "array"
|
||||
"""
|
||||
|
||||
type: Literal["array"] = "array"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ObjectType(BaseModel):
|
||||
"""Parameter type for object values.
|
||||
|
||||
:param type: Discriminator type. Always "object"
|
||||
"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JsonType(BaseModel):
|
||||
"""Parameter type for JSON values.
|
||||
|
||||
:param type: Discriminator type. Always "json"
|
||||
"""
|
||||
|
||||
type: Literal["json"] = "json"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnionType(BaseModel):
|
||||
"""Parameter type for union values.
|
||||
|
||||
:param type: Discriminator type. Always "union"
|
||||
"""
|
||||
|
||||
type: Literal["union"] = "union"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionInputType(BaseModel):
|
||||
"""Parameter type for chat completion input.
|
||||
|
||||
:param type: Discriminator type. Always "chat_completion_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages
|
||||
type: Literal["chat_completion_input"] = "chat_completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionInputType(BaseModel):
|
||||
"""Parameter type for completion input.
|
||||
|
||||
:param type: Discriminator type. Always "completion_input"
|
||||
"""
|
||||
|
||||
# expects InterleavedTextMedia for content
|
||||
type: Literal["completion_input"] = "completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnInputType(BaseModel):
|
||||
"""Parameter type for agent turn input.
|
||||
|
||||
:param type: Discriminator type. Always "agent_turn_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages (may also include attachments?)
|
||||
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogType(BaseModel):
|
||||
"""Parameter type for dialog data with semantic output labels.
|
||||
|
||||
:param type: Discriminator type. Always "dialog"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages
|
||||
# this type semantically contains the output label whereas ChatCompletionInputType does not
|
||||
type: Literal["dialog"] = "dialog"
|
||||
|
||||
|
||||
ParamType = Annotated[
|
||||
StringType
|
||||
| NumberType
|
||||
| BooleanType
|
||||
| ArrayType
|
||||
| ObjectType
|
||||
| JsonType
|
||||
| UnionType
|
||||
| ChatCompletionInputType
|
||||
| CompletionInputType
|
||||
| AgentTurnInputType,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
# will cause infinite recursion in OpenAPI generation script
|
||||
# since we are going with ChatCompletionInputType and CompletionInputType
|
||||
# we don't need to worry about ArrayType/ObjectType/UnionType for now
|
||||
ArrayType.model_rebuild()
|
||||
ObjectType.model_rebuild()
|
||||
UnionType.model_rebuild()
|
||||
|
||||
|
||||
class CustomType(BaseModel):
|
||||
pylint: disable=syntax-error
|
||||
type: Literal["custom"] = "custom"
|
||||
validator_class: str
|
||||
"""
|
||||
31
src/llama_stack/apis/conversations/__init__.py
Normal file
31
src/llama_stack/apis/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationCreateRequest,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
ConversationUpdateRequest,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"ConversationCreateRequest",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItem",
|
||||
"ConversationItemCreateRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"ConversationItemList",
|
||||
"Conversations",
|
||||
"ConversationUpdateRequest",
|
||||
"Metadata",
|
||||
]
|
||||
298
src/llama_stack/apis/conversations/conversations.py
Normal file
298
src/llama_stack/apis/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
Metadata = dict[str, str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Conversation(BaseModel):
|
||||
"""OpenAI-compatible conversation object."""
|
||||
|
||||
id: str = Field(..., description="The unique ID of the conversation.")
|
||||
object: Literal["conversation"] = Field(
|
||||
default="conversation", description="The object type, which is always conversation."
|
||||
)
|
||||
created_at: int = Field(
|
||||
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default=None,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
|
||||
)
|
||||
items: list[dict] | None = Field(
|
||||
default=None,
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationMessage(BaseModel):
|
||||
"""OpenAI-compatible message item for conversations."""
|
||||
|
||||
id: str = Field(..., description="unique identifier for this message")
|
||||
content: list[dict] = Field(..., description="message content")
|
||||
role: str = Field(..., description="message role")
|
||||
status: str = Field(..., description="message status")
|
||||
type: Literal["message"] = "message"
|
||||
object: Literal["message"] = "message"
|
||||
|
||||
|
||||
ConversationItem = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
| OpenAIResponseMCPApprovalRequest
|
||||
| OpenAIResponseMCPApprovalResponse
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ConversationItem, name="ConversationItem")
|
||||
|
||||
# Using OpenAI types directly caused issues but some notes for reference:
|
||||
# Note that ConversationItem is a Annotated Union of the types below:
|
||||
# from openai.types.responses import *
|
||||
# from openai.types.responses.response_item import *
|
||||
# from openai.types.conversations import ConversationItem
|
||||
# f = [
|
||||
# ResponseFunctionToolCallItem,
|
||||
# ResponseFunctionToolCallOutputItem,
|
||||
# ResponseFileSearchToolCall,
|
||||
# ResponseFunctionWebSearch,
|
||||
# ImageGenerationCall,
|
||||
# ResponseComputerToolCall,
|
||||
# ResponseComputerToolCallOutputItem,
|
||||
# ResponseReasoningItem,
|
||||
# ResponseCodeInterpreterToolCall,
|
||||
# LocalShellCall,
|
||||
# LocalShellCallOutput,
|
||||
# McpListTools,
|
||||
# McpApprovalRequest,
|
||||
# McpApprovalResponse,
|
||||
# McpCall,
|
||||
# ResponseCustomToolCall,
|
||||
# ResponseCustomToolCallOutput
|
||||
# ]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationCreateRequest(BaseModel):
|
||||
"""Request body for creating a conversation."""
|
||||
|
||||
items: list[ConversationItem] | None = Field(
|
||||
default=[],
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default={},
|
||||
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
|
||||
max_length=16,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationUpdateRequest(BaseModel):
|
||||
"""Request body for updating a conversation."""
|
||||
|
||||
metadata: Metadata = Field(
|
||||
...,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation."""
|
||||
|
||||
id: str = Field(..., description="The deleted conversation identifier")
|
||||
object: str = Field(default="conversation.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemCreateRequest(BaseModel):
|
||||
"""Request body for creating conversation items."""
|
||||
|
||||
items: list[ConversationItem] = Field(
|
||||
...,
|
||||
description="Items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
|
||||
|
||||
class ConversationItemInclude(StrEnum):
|
||||
"""
|
||||
Specify additional output data to include in the model response.
|
||||
"""
|
||||
|
||||
web_search_call_action_sources = "web_search_call.action.sources"
|
||||
code_interpreter_call_outputs = "code_interpreter_call.outputs"
|
||||
computer_call_output_output_image_url = "computer_call_output.output.image_url"
|
||||
file_search_call_results = "file_search_call.results"
|
||||
message_input_image_image_url = "message.input_image.image_url"
|
||||
message_output_text_logprobs = "message.output_text.logprobs"
|
||||
reasoning_encrypted_content = "reasoning.encrypted_content"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemList(BaseModel):
|
||||
"""List of conversation items with pagination."""
|
||||
|
||||
object: str = Field(default="list", description="Object type")
|
||||
data: list[ConversationItem] = Field(..., description="List of conversation items")
|
||||
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
|
||||
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more items available")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation item."""
|
||||
|
||||
id: str = Field(..., description="The deleted item identifier")
|
||||
object: str = Field(default="conversation.item.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Conversations(Protocol):
|
||||
"""Conversations
|
||||
|
||||
Protocol for conversation management operations."""
|
||||
|
||||
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation.
|
||||
|
||||
Create a conversation.
|
||||
|
||||
:param items: Initial items to include in the conversation context.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The created conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Retrieve a conversation.
|
||||
|
||||
Get a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation.
|
||||
|
||||
Update a conversation's metadata with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The updated conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation.
|
||||
|
||||
Delete a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The deleted conversation resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create items.
|
||||
|
||||
Create items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param items: Items to include in the conversation context.
|
||||
:returns: List of created items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve an item.
|
||||
|
||||
Retrieve a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The conversation item.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_items(
|
||||
self,
|
||||
conversation_id: str,
|
||||
after: str | None = None,
|
||||
include: list[ConversationItemInclude] | None = None,
|
||||
limit: int | None = None,
|
||||
order: Literal["asc", "desc"] | None = None,
|
||||
) -> ConversationItemList:
|
||||
"""List items.
|
||||
|
||||
List items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param after: An item ID to list items after, used in pagination.
|
||||
:param include: Specify additional output data to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned (1-100, default 20).
|
||||
:param order: The order to return items in (asc or desc, default desc).
|
||||
:returns: List of conversation items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete an item.
|
||||
|
||||
Delete a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The deleted item resource.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/datasetio/__init__.py
Normal file
7
src/llama_stack/apis/datasetio/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datasetio import *
|
||||
59
src/llama_stack/apis/datasetio/datasetio.py
Normal file
59
src/llama_stack/apis/datasetio/datasetio.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class DatasetStore(Protocol):
|
||||
def get_dataset(self, dataset_id: str) -> Dataset: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DatasetIO(Protocol):
|
||||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Get a paginated list of rows from a dataset.
|
||||
|
||||
Uses offset-based pagination where:
|
||||
- start_index: The starting index (0-based). If None, starts from beginning.
|
||||
- limit: Number of items to return. If None or -1, returns all items.
|
||||
|
||||
The response includes:
|
||||
- data: List of items for the current page.
|
||||
- has_more: Whether there are more items available after this set.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param start_index: Index into dataset for the first row to get. Get all rows if None.
|
||||
:param limit: The number of rows to get.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
:param dataset_id: The ID of the dataset to append the rows to.
|
||||
:param rows: The rows to append to the dataset.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/datasets/__init__.py
Normal file
7
src/llama_stack/apis/datasets/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datasets import *
|
||||
251
src/llama_stack/apis/datasets/datasets.py
Normal file
251
src/llama_stack/apis/datasets/datasets.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
class DatasetPurpose(StrEnum):
|
||||
"""
|
||||
Purpose of the dataset. Each purpose has a required input data schema.
|
||||
|
||||
:cvar post-training/messages: The dataset contains messages used for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
:cvar eval/question-answer: The dataset contains a question column and an answer column.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
"""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
eval_question_answer = "eval/question-answer"
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
|
||||
# TODO: add more schemas here
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""
|
||||
Type of the dataset source.
|
||||
:cvar uri: The dataset can be obtained from a URI.
|
||||
:cvar rows: The dataset is stored in rows.
|
||||
"""
|
||||
|
||||
uri = "uri"
|
||||
rows = "rows"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI.
|
||||
:param uri: The dataset can be obtained from a URI. E.g.
|
||||
- "https://mywebsite.com/mydata.jsonl"
|
||||
- "lsfs://mydata.jsonl"
|
||||
- "data:csv;base64,{base64_content}"
|
||||
"""
|
||||
|
||||
type: Literal["uri"] = "uri"
|
||||
uri: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows.
|
||||
:param rows: The dataset is stored in rows. E.g.
|
||||
- [
|
||||
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
|
||||
]
|
||||
"""
|
||||
|
||||
type: Literal["rows"] = "rows"
|
||||
rows: list[dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
URIDataSource | RowsDataSource,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
"""
|
||||
Common fields for a dataset.
|
||||
|
||||
:param purpose: Purpose of the dataset indicating its intended use
|
||||
:param source: Data source configuration for the dataset
|
||||
:param metadata: Additional metadata for the dataset
|
||||
"""
|
||||
|
||||
purpose: DatasetPurpose
|
||||
source: DataSource
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
"""Dataset resource for storing and accessing training or evaluation data.
|
||||
|
||||
:param type: Type of resource, always 'dataset' for datasets
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.dataset] = ResourceType.dataset
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_dataset_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
"""Input parameters for dataset operations.
|
||||
|
||||
:param dataset_id: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
"""Response from listing datasets.
|
||||
|
||||
:param data: List of datasets
|
||||
"""
|
||||
|
||||
data: list[Dataset]
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset.
|
||||
One of:
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "lsfs://mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "data:csv;base64,{base64_content}"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
}
|
||||
- {
|
||||
"type": "rows",
|
||||
"rows": [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}.
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Dataset:
|
||||
"""Get a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
:returns: A ListDatasetsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
"""Unregister a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to unregister.
|
||||
"""
|
||||
...
|
||||
158
src/llama_stack/apis/datatypes.py
Normal file
158
src/llama_stack/apis/datatypes.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DynamicApiMeta(EnumMeta):
|
||||
def __new__(cls, name, bases, namespace):
|
||||
# Store the original enum values
|
||||
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||
|
||||
# Create the enum class
|
||||
cls = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# Store the original values for reference
|
||||
cls._original_values = original_values
|
||||
# Initialize _dynamic_values
|
||||
cls._dynamic_values = {}
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(cls, value):
|
||||
try:
|
||||
return super().__call__(value)
|
||||
except ValueError as e:
|
||||
# If this value was already dynamically added, return it
|
||||
if value in cls._dynamic_values:
|
||||
return cls._dynamic_values[value]
|
||||
|
||||
# If the value doesn't exist, create a new enum member
|
||||
# Create a new member name from the value
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return the existing member
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||
# register new APIs explicitly
|
||||
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||
|
||||
def __iter__(cls):
|
||||
# Allow iteration over both static and dynamic members
|
||||
yield from super().__iter__()
|
||||
if hasattr(cls, "_dynamic_values"):
|
||||
yield from cls._dynamic_values.values()
|
||||
|
||||
def add(cls, value):
|
||||
"""
|
||||
Add a new API to the enum.
|
||||
Used to register external APIs.
|
||||
"""
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return it
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Create a new enum member
|
||||
member = object.__new__(cls)
|
||||
member._name_ = member_name
|
||||
member._value_ = value
|
||||
|
||||
# Add it to the enum class
|
||||
cls._member_map_[member_name] = member
|
||||
cls._member_names_.append(member_name)
|
||||
cls._member_type_ = str
|
||||
|
||||
# Store it in our dynamic values
|
||||
cls._dynamic_values[value] = member
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
"""Enumeration of all available APIs in the Llama Stack system.
|
||||
:cvar providers: Provider management and configuration
|
||||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar batches: Batch processing for asynchronous API requests
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
:cvar eval: Model evaluation and benchmarking framework
|
||||
:cvar post_training: Fine-tuning and model training
|
||||
:cvar tool_runtime: Tool execution and management
|
||||
:cvar telemetry: Observability and system monitoring
|
||||
:cvar models: Model metadata and management
|
||||
:cvar shields: Safety shield implementations
|
||||
:cvar datasets: Dataset creation and management
|
||||
:cvar scoring_functions: Scoring function definitions
|
||||
:cvar benchmarks: Benchmark suite management
|
||||
:cvar tool_groups: Tool group organization
|
||||
:cvar files: File storage and management
|
||||
:cvar prompts: Prompt versions and management
|
||||
:cvar inspect: Built-in system inspection and introspection
|
||||
"""
|
||||
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
batches = "batches"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
eval = "eval"
|
||||
post_training = "post_training"
|
||||
tool_runtime = "tool_runtime"
|
||||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
vector_stores = "vector_stores" # only used for routing table
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
conversations = "conversations"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Error(BaseModel):
|
||||
"""
|
||||
Error response from the API. Roughly follows RFC 7807.
|
||||
|
||||
:param status: HTTP status code
|
||||
:param title: Error title, a short summary of the error which is invariant for an error type
|
||||
:param detail: Error detail, a longer human-readable description of the error
|
||||
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
|
||||
"""
|
||||
|
||||
status: int
|
||||
title: str
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
class ExternalApiSpec(BaseModel):
|
||||
"""Specification for an external API implementation."""
|
||||
|
||||
module: str = Field(..., description="Python module containing the API implementation")
|
||||
name: str = Field(..., description="Name of the API")
|
||||
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||
protocol: str = Field(..., description="Name of the protocol class for the API")
|
||||
7
src/llama_stack/apis/eval/__init__.py
Normal file
7
src/llama_stack/apis/eval/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .eval import *
|
||||
169
src/llama_stack/apis/eval/eval.py
Normal file
169
src/llama_stack/apis/eval/eval.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelCandidate(BaseModel):
|
||||
"""A model candidate for evaluation.
|
||||
|
||||
:param model: The model ID to evaluate.
|
||||
:param sampling_params: The sampling parameters for the model.
|
||||
:param system_message: (Optional) The system message providing instructions or context to the model.
|
||||
"""
|
||||
|
||||
type: Literal["model"] = "model"
|
||||
model: str
|
||||
sampling_params: SamplingParams
|
||||
system_message: SystemMessage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""A benchmark configuration for evaluation.
|
||||
|
||||
:param eval_candidate: The candidate to evaluate.
|
||||
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
|
||||
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
|
||||
"""
|
||||
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
default_factory=dict,
|
||||
)
|
||||
num_examples: int | None = Field(
|
||||
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
|
||||
default=None,
|
||||
)
|
||||
# we could optinally add any specific dataset config here
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateResponse(BaseModel):
|
||||
"""The response from an evaluation.
|
||||
|
||||
:param generations: The generations from the evaluation.
|
||||
:param scores: The scores from the evaluation.
|
||||
"""
|
||||
|
||||
generations: list[dict[str, Any]]
|
||||
# each key in the dict is a scoring function name
|
||||
scores: dict[str, ScoringResult]
|
||||
|
||||
|
||||
class Eval(Protocol):
|
||||
"""Evaluations
|
||||
|
||||
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:returns: The job that was created to run the evaluation.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
"""Evaluate a list of rows on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:returns: EvaluateResponse object containing generations and scores.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:returns: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
|
||||
)
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:returns: The result of the job.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/files/__init__.py
Normal file
7
src/llama_stack/apis/files/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .files import *
|
||||
199
src/llama_stack/apis/files/files.py
Normal file
199
src/llama_stack/apis/files/files.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Files API Models
|
||||
class OpenAIFilePurpose(StrEnum):
|
||||
"""
|
||||
Valid purpose values for OpenAI Files API.
|
||||
"""
|
||||
|
||||
ASSISTANTS = "assistants"
|
||||
BATCH = "batch"
|
||||
# TODO: Add other purposes as needed
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileObject(BaseModel):
|
||||
"""
|
||||
OpenAI File object as defined in the OpenAI Files API.
|
||||
|
||||
:param object: The object type, which is always "file"
|
||||
:param id: The file identifier, which can be referenced in the API endpoints
|
||||
:param bytes: The size of the file, in bytes
|
||||
:param created_at: The Unix timestamp (in seconds) for when the file was created
|
||||
:param expires_at: The Unix timestamp (in seconds) for when the file expires
|
||||
:param filename: The name of the file
|
||||
:param purpose: The intended purpose of the file
|
||||
"""
|
||||
|
||||
object: Literal["file"] = "file"
|
||||
id: str
|
||||
bytes: int
|
||||
created_at: int
|
||||
expires_at: int
|
||||
filename: str
|
||||
purpose: OpenAIFilePurpose
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ExpiresAfter(BaseModel):
|
||||
"""
|
||||
Control expiration of uploaded files.
|
||||
|
||||
Params:
|
||||
- anchor, must be "created_at"
|
||||
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
|
||||
"""
|
||||
|
||||
MIN: ClassVar[int] = 3600 # 1 hour
|
||||
MAX: ClassVar[int] = 2592000 # 30 days
|
||||
|
||||
anchor: Literal["created_at"]
|
||||
seconds: int = Field(..., ge=3600, le=2592000)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIFileResponse(BaseModel):
|
||||
"""
|
||||
Response for listing files in OpenAI Files API.
|
||||
|
||||
:param data: List of file objects
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
:param first_id: ID of the first file in the list for pagination
|
||||
:param last_id: ID of the last file in the list for pagination
|
||||
:param object: The object type, which is always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIFileObject]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileDeleteResponse(BaseModel):
|
||||
"""
|
||||
Response for deleting a file in OpenAI Files API.
|
||||
|
||||
:param id: The file identifier that was deleted
|
||||
:param object: The object type, which is always "file"
|
||||
:param deleted: Whether the file was successfully deleted
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["file"] = "file"
|
||||
deleted: bool
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
This API is used to upload documents that can be used with other Llama Stack APIs.
|
||||
"""
|
||||
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload file.
|
||||
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
||||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
- expires_after: Optional form values describing expiration for the file.
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
:param expires_after: Optional form values describing expiration for the file.
|
||||
:returns: An OpenAIFileObject representing the uploaded file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""List files.
|
||||
|
||||
Returns a list of files that belong to the user's organization.
|
||||
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param purpose: Only return files with the given purpose.
|
||||
:returns: An ListOpenAIFileResponse containing the list of files.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileObject:
|
||||
"""Retrieve file.
|
||||
|
||||
Returns information about a specific file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileObject containing file information.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileDeleteResponse:
|
||||
"""Delete file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""Retrieve file content.
|
||||
|
||||
Returns the contents of the specified file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: The raw file content as a binary response.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/inference/__init__.py
Normal file
7
src/llama_stack/apis/inference/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .inference import *
|
||||
43
src/llama_stack/apis/inference/event_logger.py
Normal file
43
src/llama_stack/apis/inference/event_logger.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
)
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
self,
|
||||
content: str = "",
|
||||
end: str = "\n",
|
||||
color="white",
|
||||
):
|
||||
self.content = content
|
||||
self.color = color
|
||||
self.end = "\n" if end is None else end
|
||||
|
||||
def print(self, flush=True):
|
||||
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
|
||||
|
||||
|
||||
class EventLogger:
|
||||
async def log(self, event_generator):
|
||||
async for chunk in event_generator:
|
||||
if isinstance(chunk, ChatCompletionResponseStreamChunk):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.progress:
|
||||
yield LogEvent(event.delta, color="yellow", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
yield LogEvent("")
|
||||
else:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
yield LogEvent(chunk.completion_message.content, color="yellow")
|
||||
1274
src/llama_stack/apis/inference/inference.py
Normal file
1274
src/llama_stack/apis/inference/inference.py
Normal file
File diff suppressed because it is too large
Load diff
7
src/llama_stack/apis/inspect/__init__.py
Normal file
7
src/llama_stack/apis/inspect/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .inspect import *
|
||||
94
src/llama_stack/apis/inspect/inspect.py
Normal file
94
src/llama_stack/apis/inspect/inspect.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
"""Information about an API route including its path, method, and implementing providers.
|
||||
|
||||
:param route: The API endpoint path
|
||||
:param method: HTTP method for the route
|
||||
:param provider_types: List of provider types that implement this route
|
||||
"""
|
||||
|
||||
route: str
|
||||
method: str
|
||||
provider_types: list[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
"""Health status information for the service.
|
||||
|
||||
:param status: Current health status of the service
|
||||
"""
|
||||
|
||||
status: HealthStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
"""Version information for the service.
|
||||
|
||||
:param version: Version number of the service
|
||||
"""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
"""Response containing a list of all available API routes.
|
||||
|
||||
:param data: List of available route information objects
|
||||
"""
|
||||
|
||||
data: list[RouteInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
"""Inspect
|
||||
|
||||
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
|
||||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get health status.
|
||||
|
||||
Get the current health status of the service.
|
||||
|
||||
:returns: Health information indicating if the service is operational.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get version.
|
||||
|
||||
Get the version of the service.
|
||||
|
||||
:returns: Version information containing the service version number.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/models/__init__.py
Normal file
7
src/llama_stack/apis/models/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .models import *
|
||||
171
src/llama_stack/apis/models/models.py
Normal file
171
src/llama_stack/apis/models/models.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this model",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelType(StrEnum):
|
||||
"""Enumeration of supported model types in Llama Stack.
|
||||
:cvar llm: Large language model for text generation and completion
|
||||
:cvar embedding: Embedding model for converting text to vector representations
|
||||
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
|
||||
"""
|
||||
|
||||
llm = "llm"
|
||||
embedding = "embedding"
|
||||
rerank = "rerank"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Model(CommonModelFields, Resource):
|
||||
"""A model resource representing an AI model registered in Llama Stack.
|
||||
|
||||
:param type: The resource type, always 'model' for model resources
|
||||
:param model_type: The type of model (LLM or embedding model)
|
||||
:param metadata: Any additional metadata for this model
|
||||
:param identifier: Unique identifier for this resource in llama stack
|
||||
:param provider_resource_id: Unique identifier for this resource in the provider
|
||||
:param provider_id: ID of the provider that owns this resource
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.model] = ResourceType.model
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_model_id(self) -> str:
|
||||
assert self.provider_resource_id is not None, "Provider resource ID must be set"
|
||||
return self.provider_resource_id
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
model_type: ModelType = Field(default=ModelType.llm)
|
||||
|
||||
@field_validator("provider_resource_id")
|
||||
@classmethod
|
||||
def validate_provider_resource_id(cls, v):
|
||||
if v is None:
|
||||
raise ValueError("provider_resource_id cannot be None")
|
||||
return v
|
||||
|
||||
|
||||
class ModelInput(CommonModelFields):
|
||||
model_id: str
|
||||
provider_id: str | None = None
|
||||
provider_model_id: str | None = None
|
||||
model_type: ModelType | None = ModelType.llm
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ListModelsResponse(BaseModel):
|
||||
data: list[Model]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIModel(BaseModel):
|
||||
"""A model from OpenAI.
|
||||
|
||||
:id: The ID of the model
|
||||
:object: The object type, which will be "model"
|
||||
:created: The Unix timestamp in seconds when the model was created
|
||||
:owned_by: The owner of the model
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["model"] = "model"
|
||||
created: int
|
||||
owned_by: str
|
||||
|
||||
|
||||
class OpenAIListModelsResponse(BaseModel):
|
||||
data: list[OpenAIModel]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
||||
:returns: A ListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Model:
|
||||
"""Get model.
|
||||
|
||||
Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register model.
|
||||
|
||||
Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param metadata: Any additional metadata for this model.
|
||||
:param model_type: The type of model to register.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> None:
|
||||
"""Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/post_training/__init__.py
Normal file
7
src/llama_stack/apis/post_training/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .post_training import *
|
||||
374
src/llama_stack/apis/post_training/post_training.py
Normal file
374
src/llama_stack/apis/post_training/post_training.py
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerType(Enum):
|
||||
"""Available optimizer algorithms for training.
|
||||
:cvar adam: Adaptive Moment Estimation optimizer
|
||||
:cvar adamw: AdamW optimizer with weight decay
|
||||
:cvar sgd: Stochastic Gradient Descent optimizer
|
||||
"""
|
||||
|
||||
adam = "adam"
|
||||
adamw = "adamw"
|
||||
sgd = "sgd"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatasetFormat(Enum):
|
||||
"""Format of the training dataset.
|
||||
:cvar instruct: Instruction-following format with prompt and completion
|
||||
:cvar dialog: Multi-turn conversation format with messages
|
||||
"""
|
||||
|
||||
instruct = "instruct"
|
||||
dialog = "dialog"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DataConfig(BaseModel):
|
||||
"""Configuration for training data and data loading.
|
||||
|
||||
:param dataset_id: Unique identifier for the training dataset
|
||||
:param batch_size: Number of samples per training batch
|
||||
:param shuffle: Whether to shuffle the dataset during training
|
||||
:param data_format: Format of the dataset (instruct or dialog)
|
||||
:param validation_dataset_id: (Optional) Unique identifier for the validation dataset
|
||||
:param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
|
||||
:param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
data_format: DatasetFormat
|
||||
validation_dataset_id: str | None = None
|
||||
packed: bool | None = False
|
||||
train_on_input: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerConfig(BaseModel):
|
||||
"""Configuration parameters for the optimization algorithm.
|
||||
|
||||
:param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
|
||||
:param lr: Learning rate for the optimizer
|
||||
:param weight_decay: Weight decay coefficient for regularization
|
||||
:param num_warmup_steps: Number of steps for learning rate warmup
|
||||
"""
|
||||
|
||||
optimizer_type: OptimizerType
|
||||
lr: float
|
||||
weight_decay: float
|
||||
num_warmup_steps: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EfficiencyConfig(BaseModel):
|
||||
"""Configuration for memory and compute efficiency optimizations.
|
||||
|
||||
:param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
|
||||
:param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
|
||||
:param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
|
||||
:param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
|
||||
"""
|
||||
|
||||
enable_activation_checkpointing: bool | None = False
|
||||
enable_activation_offloading: bool | None = False
|
||||
memory_efficient_fsdp_wrap: bool | None = False
|
||||
fsdp_cpu_offload: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Comprehensive configuration for the training process.
|
||||
|
||||
:param n_epochs: Number of training epochs to run
|
||||
:param max_steps_per_epoch: Maximum number of steps to run per epoch
|
||||
:param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
|
||||
:param max_validation_steps: (Optional) Maximum number of validation steps per epoch
|
||||
:param data_config: (Optional) Configuration for data loading and formatting
|
||||
:param optimizer_config: (Optional) Configuration for the optimization algorithm
|
||||
:param efficiency_config: (Optional) Configuration for memory and compute optimizations
|
||||
:param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
|
||||
"""
|
||||
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int = 1
|
||||
gradient_accumulation_steps: int = 1
|
||||
max_validation_steps: int | None = 1
|
||||
data_config: DataConfig | None = None
|
||||
optimizer_config: OptimizerConfig | None = None
|
||||
efficiency_config: EfficiencyConfig | None = None
|
||||
dtype: str | None = "bf16"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LoraFinetuningConfig(BaseModel):
|
||||
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning.
|
||||
|
||||
:param type: Algorithm type identifier, always "LoRA"
|
||||
:param lora_attn_modules: List of attention module names to apply LoRA to
|
||||
:param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
|
||||
:param apply_lora_to_output: Whether to apply LoRA to output projection layers
|
||||
:param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
|
||||
:param alpha: LoRA scaling parameter that controls adaptation strength
|
||||
:param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
|
||||
:param quantize_base: (Optional) Whether to quantize the base model weights
|
||||
"""
|
||||
|
||||
type: Literal["LoRA"] = "LoRA"
|
||||
lora_attn_modules: list[str]
|
||||
apply_lora_to_mlp: bool
|
||||
apply_lora_to_output: bool
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool | None = False
|
||||
quantize_base: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
"""Configuration for Quantization-Aware Training (QAT) fine-tuning.
|
||||
|
||||
:param type: Algorithm type identifier, always "QAT"
|
||||
:param quantizer_name: Name of the quantization algorithm to use
|
||||
:param group_size: Size of groups for grouped quantization
|
||||
"""
|
||||
|
||||
type: Literal["QAT"] = "QAT"
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobLogStream(BaseModel):
|
||||
"""Stream of logs from a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param log_lines: List of log message strings from the training process
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
log_lines: list[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RLHFAlgorithm(Enum):
|
||||
"""Available reinforcement learning from human feedback algorithms.
|
||||
:cvar dpo: Direct Preference Optimization algorithm
|
||||
"""
|
||||
|
||||
dpo = "dpo"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DPOLossType(Enum):
|
||||
sigmoid = "sigmoid"
|
||||
hinge = "hinge"
|
||||
ipo = "ipo"
|
||||
kto_pair = "kto_pair"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DPOAlignmentConfig(BaseModel):
|
||||
"""Configuration for Direct Preference Optimization (DPO) alignment.
|
||||
|
||||
:param beta: Temperature parameter for the DPO loss
|
||||
:param loss_type: The type of loss function to use for DPO
|
||||
"""
|
||||
|
||||
beta: float
|
||||
loss_type: DPOLossType = DPOLossType.sigmoid
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingRLHFRequest(BaseModel):
|
||||
"""Request to finetune a model using reinforcement learning from human feedback.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param finetuned_model: URL or path to the base model to fine-tune
|
||||
:param dataset_id: Unique identifier for the training dataset
|
||||
:param validation_dataset_id: Unique identifier for the validation dataset
|
||||
:param algorithm: RLHF algorithm to use for training
|
||||
:param algorithm_config: Configuration parameters for the RLHF algorithm
|
||||
:param optimizer_config: Configuration parameters for the optimization algorithm
|
||||
:param training_config: Configuration parameters for the training process
|
||||
:param hyperparam_search_config: Configuration for hyperparameter search
|
||||
:param logger_config: Configuration for training logging
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
finetuned_model: URL
|
||||
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: RLHFAlgorithm
|
||||
algorithm_config: DPOAlignmentConfig
|
||||
|
||||
optimizer_config: OptimizerConfig
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
hyperparam_search_config: dict[str, Any]
|
||||
logger_config: dict[str, Any]
|
||||
|
||||
|
||||
class PostTrainingJob(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobStatusResponse(BaseModel):
|
||||
"""Status of a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param status: Current status of the training job
|
||||
:param scheduled_at: (Optional) Timestamp when the job was scheduled
|
||||
:param started_at: (Optional) Timestamp when the job execution began
|
||||
:param completed_at: (Optional) Timestamp when the job finished, if completed
|
||||
:param resources_allocated: (Optional) Information about computational resources allocated to the job
|
||||
:param checkpoints: List of model checkpoints created during training
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
status: JobStatus
|
||||
|
||||
scheduled_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
resources_allocated: dict[str, Any] | None = None
|
||||
|
||||
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ListPostTrainingJobsResponse(BaseModel):
|
||||
data: list[PostTrainingJob]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a finetuning job.
|
||||
|
||||
:param job_uuid: Unique identifier for the training job
|
||||
:param checkpoints: List of model checkpoints created during training
|
||||
"""
|
||||
|
||||
job_uuid: str
|
||||
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
# TODO(ashwin): metrics, evals
|
||||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str | None = Field(
|
||||
default=None,
|
||||
description="Model descriptor for training if not in provider config`",
|
||||
),
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
"""Run supervised fine-tuning of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:param model: The model to fine-tune.
|
||||
:param checkpoint_dir: The directory to save checkpoint(s) to.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
"""Run preference optimization of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param finetuned_model: The model to fine-tune.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all training jobs.
|
||||
|
||||
:returns: A ListPostTrainingJobsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
"""Get the status of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the status of.
|
||||
:returns: A PostTrainingJobStatusResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancel a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
"""Get the artifacts of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the artifacts of.
|
||||
:returns: A PostTrainingJobArtifactsResponse.
|
||||
"""
|
||||
...
|
||||
9
src/llama_stack/apis/prompts/__init__.py
Normal file
9
src/llama_stack/apis/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
|
||||
204
src/llama_stack/apis/prompts/prompts.py
Normal file
204
src/llama_stack/apis/prompts/prompts.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import secrets
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
|
||||
|
||||
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
||||
:param version: Version (integer starting at 1, incremented on save)
|
||||
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||
:param variables: List of prompt variable names that can be used in the prompt template
|
||||
:param is_default: Boolean indicating whether this version is the default version for this prompt
|
||||
"""
|
||||
|
||||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
|
||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||
variables: list[str] = Field(
|
||||
default_factory=list, description="List of variable names that can be used in the prompt template"
|
||||
)
|
||||
is_default: bool = Field(
|
||||
default=False, description="Boolean indicating whether this version is the default version"
|
||||
)
|
||||
|
||||
@field_validator("prompt_id")
|
||||
@classmethod
|
||||
def validate_prompt_id(cls, prompt_id: str) -> str:
|
||||
if not isinstance(prompt_id, str):
|
||||
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
|
||||
|
||||
if not prompt_id.startswith("pmpt_"):
|
||||
raise ValueError("prompt_id must start with 'pmpt_' prefix")
|
||||
|
||||
hex_part = prompt_id[5:]
|
||||
if len(hex_part) != 48:
|
||||
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
|
||||
|
||||
for char in hex_part:
|
||||
if char not in "0123456789abcdef":
|
||||
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
|
||||
|
||||
return prompt_id
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, prompt_version: int) -> int:
|
||||
if prompt_version < 1:
|
||||
raise ValueError("version must be >= 1")
|
||||
return prompt_version
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_prompt_variables(self):
|
||||
"""Validate that all variables used in the prompt are declared in the variables list."""
|
||||
if not self.prompt:
|
||||
return self
|
||||
|
||||
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
|
||||
declared_variables = set(self.variables)
|
||||
|
||||
undeclared = prompt_variables - declared_variables
|
||||
if undeclared:
|
||||
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def generate_prompt_id(cls) -> str:
|
||||
# Generate 48 hex characters (24 bytes)
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
hex_string = random_bytes.hex()
|
||||
return f"pmpt_{hex_string}"
|
||||
|
||||
|
||||
class ListPromptsResponse(BaseModel):
|
||||
"""Response model to list prompts."""
|
||||
|
||||
data: list[Prompt]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Prompts
|
||||
|
||||
Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts.
|
||||
|
||||
:returns: A ListPromptsResponse containing all prompts.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompt_versions(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> ListPromptsResponse:
|
||||
"""List prompt versions.
|
||||
|
||||
List all versions of a specific prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to list versions for.
|
||||
:returns: A ListPromptsResponse containing all versions of the prompt.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get prompt.
|
||||
|
||||
Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
:returns: A Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create prompt.
|
||||
|
||||
Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
:returns: The created Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update prompt.
|
||||
|
||||
Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
:param version: The current version of the prompt being updated.
|
||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||
:param set_as_default: Set the new version as the default (default=True).
|
||||
:returns: The updated Prompt resource with incremented version.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete prompt.
|
||||
|
||||
Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set prompt version.
|
||||
|
||||
Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
:returns: The prompt with the specified version now set as default.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/providers/__init__.py
Normal file
7
src/llama_stack/apis/providers/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .providers import *
|
||||
69
src/llama_stack/apis/providers/providers.py
Normal file
69
src/llama_stack/apis/providers/providers.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about a registered provider including its configuration and health status.
|
||||
|
||||
:param api: The API name this provider implements
|
||||
:param provider_id: Unique identifier for the provider
|
||||
:param provider_type: The type of provider implementation
|
||||
:param config: Configuration parameters for the provider
|
||||
:param health: Current health status of the provider
|
||||
"""
|
||||
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
health: HealthResponse
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
"""Response containing a list of all available providers.
|
||||
|
||||
:param data: List of provider information objects
|
||||
"""
|
||||
|
||||
data: list[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""Providers
|
||||
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List providers.
|
||||
|
||||
List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get provider.
|
||||
|
||||
Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
37
src/llama_stack/apis/resource.py
Normal file
37
src/llama_stack/apis/resource.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResourceType(StrEnum):
|
||||
model = "model"
|
||||
shield = "shield"
|
||||
vector_store = "vector_store"
|
||||
dataset = "dataset"
|
||||
scoring_function = "scoring_function"
|
||||
benchmark = "benchmark"
|
||||
tool = "tool"
|
||||
tool_group = "tool_group"
|
||||
prompt = "prompt"
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""Base class for all Llama Stack resources"""
|
||||
|
||||
identifier: str = Field(description="Unique identifier for this resource in llama stack")
|
||||
|
||||
provider_resource_id: str | None = Field(
|
||||
default=None,
|
||||
description="Unique identifier for this resource in the provider",
|
||||
)
|
||||
|
||||
provider_id: str = Field(description="ID of the provider that owns this resource")
|
||||
|
||||
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_store', etc.)")
|
||||
7
src/llama_stack/apis/safety/__init__.py
Normal file
7
src/llama_stack/apis/safety/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .safety import *
|
||||
135
src/llama_stack/apis/safety/safety.py
Normal file
135
src/llama_stack/apis/safety/safety.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
:param flagged: Whether any of the below categories are flagged.
|
||||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
categories: dict[str, bool] | None = None
|
||||
category_applied_input_types: dict[str, list[str]] | None = None
|
||||
category_scores: dict[str, float] | None = None
|
||||
user_message: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObject(BaseModel):
|
||||
"""A moderation object.
|
||||
:param id: The unique identifier for the moderation request.
|
||||
:param model: The model used to generate the moderation results.
|
||||
:param results: A list of moderation objects
|
||||
"""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
results: list[ModerationObjectResults]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
"""Severity level of a safety violation.
|
||||
|
||||
:cvar INFO: Informational level violation that does not require action
|
||||
:cvar WARN: Warning level violation that suggests caution but allows continuation
|
||||
:cvar ERROR: Error level violation that requires blocking or intervention
|
||||
"""
|
||||
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SafetyViolation(BaseModel):
|
||||
"""Details of a safety violation detected by content moderation.
|
||||
|
||||
:param violation_level: Severity level of the violation
|
||||
:param user_message: (Optional) Message to convey to the user about the violation
|
||||
:param metadata: Additional metadata including specific violation codes for debugging and telemetry
|
||||
"""
|
||||
|
||||
violation_level: ViolationLevel
|
||||
|
||||
# what message should you convey to the user
|
||||
user_message: str | None = None
|
||||
|
||||
# additional metadata (including specific violation codes) more for
|
||||
# debugging, telemetry
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunShieldResponse(BaseModel):
|
||||
"""Response from running a safety shield.
|
||||
|
||||
:param violation: (Optional) Safety violation detected by the shield, if any
|
||||
"""
|
||||
|
||||
violation: SafetyViolation | None = None
|
||||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
OpenAI-compatible Moderations API.
|
||||
"""
|
||||
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
"""Run shield.
|
||||
|
||||
Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
"""Create moderation.
|
||||
|
||||
Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: (Optional) The content moderation model you would like to use.
|
||||
:returns: A moderation object.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/scoring/__init__.py
Normal file
7
src/llama_stack/apis/scoring/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .scoring import *
|
||||
93
src/llama_stack/apis/scoring/scoring.py
Normal file
93
src/llama_stack/apis/scoring/scoring.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# mapping of metric to value
|
||||
ScoringResultRow = dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringResult(BaseModel):
|
||||
"""
|
||||
A scoring result for a single row.
|
||||
|
||||
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
|
||||
:param aggregated_results: Map of metric name to aggregated value
|
||||
"""
|
||||
|
||||
score_rows: list[ScoringResultRow]
|
||||
# aggregated metrics to value
|
||||
aggregated_results: dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoreBatchResponse(BaseModel):
|
||||
"""Response from batch scoring operations on datasets.
|
||||
|
||||
:param dataset_id: (Optional) The identifier of the dataset that was scored
|
||||
:param results: A map of scoring function name to ScoringResult
|
||||
"""
|
||||
|
||||
dataset_id: str | None = None
|
||||
results: dict[str, ScoringResult]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoreResponse(BaseModel):
|
||||
"""
|
||||
The response from scoring.
|
||||
|
||||
:param results: A map of scoring function name to ScoringResult.
|
||||
"""
|
||||
|
||||
# each key in the dict is a scoring function name
|
||||
results: dict[str, ScoringResult]
|
||||
|
||||
|
||||
class ScoringFunctionStore(Protocol):
|
||||
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Scoring(Protocol):
|
||||
scoring_function_store: ScoringFunctionStore
|
||||
|
||||
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
"""Score a batch of rows.
|
||||
|
||||
:param dataset_id: The ID of the dataset to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:param save_results_dataset: Whether to save the results to a dataset.
|
||||
:returns: A ScoreBatchResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score(
|
||||
self,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
) -> ScoreResponse:
|
||||
"""Score a list of rows.
|
||||
|
||||
:param input_rows: The rows to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:returns: A ScoreResponse object containing rows and aggregated results.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/scoring_functions/__init__.py
Normal file
7
src/llama_stack/apis/scoring_functions/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .scoring_functions import *
|
||||
208
src/llama_stack/apis/scoring_functions/scoring_functions.py
Normal file
208
src/llama_stack/apis/scoring_functions/scoring_functions.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
from enum import StrEnum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(StrEnum):
|
||||
"""Types of scoring function parameter configurations.
|
||||
:cvar llm_as_judge: Use an LLM model to evaluate and score responses
|
||||
:cvar regex_parser: Use regex patterns to extract and score specific parts of responses
|
||||
:cvar basic: Basic scoring with simple aggregation functions
|
||||
"""
|
||||
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(StrEnum):
|
||||
"""Types of aggregation functions for scoring results.
|
||||
:cvar average: Calculate the arithmetic mean of scores
|
||||
:cvar weighted_average: Calculate a weighted average of scores
|
||||
:cvar median: Calculate the median value of scores
|
||||
:cvar categorical_count: Count occurrences of categorical values
|
||||
:cvar accuracy: Calculate accuracy as the proportion of correct answers
|
||||
"""
|
||||
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
"""Parameters for LLM-as-judge scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always llm_as_judge
|
||||
:param judge_model: Identifier of the LLM model to use as a judge for scoring
|
||||
:param prompt_template: (Optional) Custom prompt template for the judge model
|
||||
:param judge_score_regexes: Regexes to extract the answer from generated response
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
|
||||
judge_model: str
|
||||
prompt_template: str | None = None
|
||||
judge_score_regexes: list[str] = Field(
|
||||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
"""Parameters for regex parser scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always regex_parser
|
||||
:param parsing_regexes: Regex to extract the answer from generated response
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
|
||||
parsing_regexes: list[str] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
"""Parameters for basic scoring function configuration.
|
||||
:param type: The type of scoring function parameters, always basic
|
||||
:param aggregation_functions: Aggregation functions to apply to the scores of each row
|
||||
"""
|
||||
|
||||
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
ScoringFnParams = Annotated[
|
||||
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
description: str | None = None
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this definition",
|
||||
)
|
||||
return_type: ParamType = Field(
|
||||
description="The return type of the deterministic function",
|
||||
)
|
||||
params: ScoringFnParams | None = Field(
|
||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
"""A scoring function resource for evaluating model outputs.
|
||||
:param type: The resource type, always scoring_function
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
|
||||
|
||||
@property
|
||||
def scoring_fn_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_scoring_fn_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||
scoring_fn_id: str
|
||||
provider_id: str | None = None
|
||||
provider_scoring_fn_id: str | None = None
|
||||
|
||||
|
||||
class ListScoringFunctionsResponse(BaseModel):
|
||||
data: list[ScoringFn]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
"""List all scoring functions.
|
||||
|
||||
:returns: A ListScoringFunctionsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
|
||||
"""Get a scoring function by its ID.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to get.
|
||||
:returns: A ScoringFn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: ScoringFnParams | None = None,
|
||||
) -> None:
|
||||
"""Register a scoring function.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to register.
|
||||
:param description: The description of the scoring function.
|
||||
:param return_type: The return type of the scoring function.
|
||||
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
|
||||
:param provider_id: The ID of the provider to use for the scoring function.
|
||||
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
"""Unregister a scoring function.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to unregister.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/shields/__init__.py
Normal file
7
src/llama_stack/apis/shields/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .shields import *
|
||||
94
src/llama_stack/apis/shields/shields.py
Normal file
94
src/llama_stack/apis/shields/shields.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Shield(CommonShieldFields, Resource):
|
||||
"""A safety shield resource that can be used to check content.
|
||||
|
||||
:param params: (Optional) Configuration parameters for the shield
|
||||
:param type: The resource type, always shield
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.shield] = ResourceType.shield
|
||||
|
||||
@property
|
||||
def shield_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_shield_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class ShieldInput(CommonShieldFields):
|
||||
shield_id: str
|
||||
provider_id: str | None = None
|
||||
provider_shield_id: str | None = None
|
||||
|
||||
|
||||
class ListShieldsResponse(BaseModel):
|
||||
data: list[Shield]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
"""List all shields.
|
||||
|
||||
:returns: A ListShieldsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
"""Get a shield by its identifier.
|
||||
|
||||
:param identifier: The identifier of the shield to get.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
"""Register a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to register.
|
||||
:param provider_shield_id: The identifier of the shield in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
:param identifier: The identifier of the shield to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .synthetic_data_generation import *
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function.
|
||||
|
||||
:cvar none: No filtering applied, accept all generated synthetic data
|
||||
:cvar random: Random sampling of generated data points
|
||||
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
|
||||
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
|
||||
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
|
||||
:cvar sigmoid: Apply sigmoid function for probability-based filtering
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
random = "random"
|
||||
top_k = "top_k"
|
||||
top_p = "top_p"
|
||||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationRequest(BaseModel):
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
"""
|
||||
|
||||
dialogs: list[Message]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
model: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
|
||||
|
||||
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
|
||||
:param statistics: (Optional) Statistical information about the generation process and filtering results
|
||||
"""
|
||||
|
||||
synthetic_data: list[dict[str, Any]]
|
||||
statistics: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: list[Message],
|
||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||
model: str | None = None,
|
||||
) -> SyntheticDataGenerationResponse:
|
||||
"""Generate synthetic data based on input dialogs and apply filtering.
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
:returns: Response containing filtered synthetic data samples and optional statistics
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/telemetry/__init__.py
Normal file
7
src/llama_stack/apis/telemetry/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .telemetry import *
|
||||
423
src/llama_stack/apis/telemetry/telemetry.py
Normal file
423
src/llama_stack/apis/telemetry/telemetry.py
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
"""The status of a span indicating whether it completed successfully or with an error.
|
||||
:cvar OK: Span completed successfully without errors
|
||||
:cvar ERROR: Span completed with an error or failure
|
||||
"""
|
||||
|
||||
OK = "ok"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Span(BaseModel):
|
||||
"""A span representing a single operation within a trace.
|
||||
:param span_id: Unique identifier for the span
|
||||
:param trace_id: Unique identifier for the trace this span belongs to
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param start_time: Timestamp when the operation began
|
||||
:param end_time: (Optional) Timestamp when the operation finished, if completed
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
|
||||
"""
|
||||
|
||||
span_id: str
|
||||
trace_id: str
|
||||
parent_span_id: str | None = None
|
||||
name: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
"""A trace representing the complete execution path of a request across multiple operations.
|
||||
:param trace_id: Unique identifier for the trace
|
||||
:param root_span_id: Unique identifier for the root span that started this trace
|
||||
:param start_time: Timestamp when the trace began
|
||||
:param end_time: (Optional) Timestamp when the trace finished, if completed
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
root_span_id: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EventType(Enum):
|
||||
"""The type of telemetry event being logged.
|
||||
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
|
||||
:cvar STRUCTURED_LOG: A structured log event with typed payload data
|
||||
:cvar METRIC: A metric measurement with value and unit
|
||||
"""
|
||||
|
||||
UNSTRUCTURED_LOG = "unstructured_log"
|
||||
STRUCTURED_LOG = "structured_log"
|
||||
METRIC = "metric"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogSeverity(Enum):
|
||||
"""The severity level of a log message.
|
||||
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
|
||||
:cvar DEBUG: Debug information useful during development
|
||||
:cvar INFO: General informational messages about normal operation
|
||||
:cvar WARN: Warning messages about potentially problematic situations
|
||||
:cvar ERROR: Error messages indicating failures that don't stop execution
|
||||
:cvar CRITICAL: Critical error messages indicating severe failures
|
||||
"""
|
||||
|
||||
VERBOSE = "verbose"
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class EventCommon(BaseModel):
|
||||
"""Common fields shared by all telemetry events.
|
||||
:param trace_id: Unique identifier for the trace this event belongs to
|
||||
:param span_id: Unique identifier for the span this event belongs to
|
||||
:param timestamp: Timestamp when the event occurred
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
"""An unstructured log event containing a simple text message.
|
||||
:param type: Event type identifier set to UNSTRUCTURED_LOG
|
||||
:param message: The log message text
|
||||
:param severity: The severity level of the log message
|
||||
"""
|
||||
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
"""A metric event containing a measured value.
|
||||
:param type: Event type identifier set to METRIC
|
||||
:param metric: The name of the metric being measured
|
||||
:param value: The numeric value of the metric measurement
|
||||
:param unit: The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
"""A metric value included in API responses.
|
||||
:param metric: The name of the metric
|
||||
:param value: The numeric value of the metric
|
||||
:param unit: (Optional) The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
metric: str
|
||||
value: int | float
|
||||
unit: str | None = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be included with the response
|
||||
# To do this, we will need to augment all response types with a metrics field.
|
||||
# We have hit a blocker from stainless SDK that prevents us from doing this.
|
||||
# The blocker is that if we were to augment the response types that have a data field
|
||||
# in them like so
|
||||
# class ListModelsResponse(BaseModel):
|
||||
# metrics: Optional[List[MetricEvent]] = None
|
||||
# data: List[Models]
|
||||
# ...
|
||||
# The client SDK will need to access the data by using a .data field, which is not
|
||||
# ergonomic. Stainless SDK does support unwrapping the response type, but it
|
||||
# requires that the response type to only have a single field.
|
||||
|
||||
# We will need a way in the client SDK to signal that the metrics are needed
|
||||
# and if they are needed, the client SDK has to return the full response type
|
||||
# without unwrapping it.
|
||||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
"""Mixin class for API responses that can include metrics.
|
||||
:param metrics: (Optional) List of metrics associated with the API response
|
||||
"""
|
||||
|
||||
metrics: list[MetricInResponse] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
"""The type of structured log event payload.
|
||||
:cvar SPAN_START: Event indicating the start of a new span
|
||||
:cvar SPAN_END: Event indicating the completion of a span
|
||||
"""
|
||||
|
||||
SPAN_START = "span_start"
|
||||
SPAN_END = "span_end"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
"""Payload for a span start event.
|
||||
:param type: Payload type identifier set to SPAN_START
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
|
||||
name: str
|
||||
parent_span_id: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanEndPayload(BaseModel):
|
||||
"""Payload for a span end event.
|
||||
:param type: Payload type identifier set to SPAN_END
|
||||
:param status: The final status of the span indicating success or failure
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
|
||||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = Annotated[
|
||||
SpanStartPayload | SpanEndPayload,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogEvent(EventCommon):
|
||||
"""A structured log event containing typed payload data.
|
||||
:param type: Event type identifier set to STRUCTURED_LOG
|
||||
:param payload: The structured payload data for the log event
|
||||
"""
|
||||
|
||||
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = Annotated[
|
||||
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
"""A trace record for evaluation purposes.
|
||||
:param session_id: Unique identifier for the evaluation session
|
||||
:param step: The evaluation step or phase identifier
|
||||
:param input: The input data for the evaluation
|
||||
:param output: The actual output produced during evaluation
|
||||
:param expected_output: The expected output for comparison during evaluation
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
output: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanWithStatus(Span):
|
||||
"""A span that includes status information.
|
||||
:param status: (Optional) The current status of the span
|
||||
"""
|
||||
|
||||
status: SpanStatus | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryConditionOp(Enum):
|
||||
"""Comparison operators for query conditions.
|
||||
:cvar EQ: Equal to comparison
|
||||
:cvar NE: Not equal to comparison
|
||||
:cvar GT: Greater than comparison
|
||||
:cvar LT: Less than comparison
|
||||
"""
|
||||
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GT = "gt"
|
||||
LT = "lt"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
"""A condition for filtering query results.
|
||||
:param key: The attribute key to filter on
|
||||
:param op: The comparison operator to apply
|
||||
:param value: The value to compare against
|
||||
"""
|
||||
|
||||
key: str
|
||||
op: QueryConditionOp
|
||||
value: Any
|
||||
|
||||
|
||||
class QueryTracesResponse(BaseModel):
|
||||
"""Response containing a list of traces.
|
||||
:param data: List of traces matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Trace]
|
||||
|
||||
|
||||
class QuerySpansResponse(BaseModel):
|
||||
"""Response containing a list of spans.
|
||||
:param data: List of spans matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Span]
|
||||
|
||||
|
||||
class QuerySpanTreeResponse(BaseModel):
|
||||
"""Response containing a tree structure of spans.
|
||||
:param data: Dictionary mapping span IDs to spans with status information
|
||||
"""
|
||||
|
||||
data: dict[str, SpanWithStatus]
|
||||
|
||||
|
||||
class MetricQueryType(Enum):
|
||||
"""The type of metric query to perform.
|
||||
:cvar RANGE: Query metrics over a time range
|
||||
:cvar INSTANT: Query metrics at a specific point in time
|
||||
"""
|
||||
|
||||
RANGE = "range"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
class MetricLabelOperator(Enum):
|
||||
"""Operators for matching metric labels.
|
||||
:cvar EQUALS: Label value must equal the specified value
|
||||
:cvar NOT_EQUALS: Label value must not equal the specified value
|
||||
:cvar REGEX_MATCH: Label value must match the specified regular expression
|
||||
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
|
||||
"""
|
||||
|
||||
EQUALS = "="
|
||||
NOT_EQUALS = "!="
|
||||
REGEX_MATCH = "=~"
|
||||
REGEX_NOT_MATCH = "!~"
|
||||
|
||||
|
||||
class MetricLabelMatcher(BaseModel):
|
||||
"""A matcher for filtering metrics by label values.
|
||||
:param name: The name of the label to match
|
||||
:param value: The value to match against
|
||||
:param operator: The comparison operator to use for matching
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricLabel(BaseModel):
|
||||
"""A label associated with a metric.
|
||||
:param name: The name of the label
|
||||
:param value: The value of the label
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricDataPoint(BaseModel):
|
||||
"""A single data point in a metric time series.
|
||||
:param timestamp: Unix timestamp when the metric value was recorded
|
||||
:param value: The numeric value of the metric at this timestamp
|
||||
"""
|
||||
|
||||
timestamp: int
|
||||
value: float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricSeries(BaseModel):
|
||||
"""A time series of metric data points.
|
||||
:param metric: The name of the metric
|
||||
:param labels: List of labels associated with this metric series
|
||||
:param values: List of data points in chronological order
|
||||
"""
|
||||
|
||||
metric: str
|
||||
labels: list[MetricLabel]
|
||||
values: list[MetricDataPoint]
|
||||
|
||||
|
||||
class QueryMetricsResponse(BaseModel):
|
||||
"""Response containing metric time series data.
|
||||
:param data: List of metric series matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[MetricSeries]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
async def log_event(
|
||||
self,
|
||||
event: Event,
|
||||
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
|
||||
) -> None:
|
||||
"""Log an event.
|
||||
|
||||
:param event: The event to log.
|
||||
:param ttl_seconds: The time to live of the event.
|
||||
"""
|
||||
...
|
||||
8
src/llama_stack/apis/tools/__init__.py
Normal file
8
src/llama_stack/apis/tools/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .rag_tool import *
|
||||
from .tools import *
|
||||
218
src/llama_stack/apis/tools/rag_tool.py
Normal file
218
src/llama_stack/apis/tools/rag_tool.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RRFRanker(BaseModel):
|
||||
"""
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
|
||||
:param type: The type of ranker, always "rrf"
|
||||
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
|
||||
Must be greater than 0
|
||||
"""
|
||||
|
||||
type: Literal["rrf"] = "rrf"
|
||||
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WeightedRanker(BaseModel):
|
||||
"""
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
|
||||
:param type: The type of ranker, always "weighted"
|
||||
:param alpha: Weight factor between 0 and 1.
|
||||
0 means only use keyword scores,
|
||||
1 means only use vector scores,
|
||||
values in between blend both scores.
|
||||
"""
|
||||
|
||||
type: Literal["weighted"] = "weighted"
|
||||
alpha: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
|
||||
)
|
||||
|
||||
|
||||
Ranker = Annotated[
|
||||
RRFRanker | WeightedRanker,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Ranker, name="Ranker")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
|
||||
:param document_id: The unique identifier for the document.
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
:param metadata: Additional metadata for the document.
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
"""Result of a RAG query containing retrieved content and metadata.
|
||||
|
||||
:param content: (Optional) The retrieved content from the query
|
||||
:param metadata: Additional metadata about the query result
|
||||
"""
|
||||
|
||||
content: InterleavedContent | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryGenerator(Enum):
|
||||
"""Types of query generators for RAG systems.
|
||||
|
||||
:cvar default: Default query generator using simple text processing
|
||||
:cvar llm: LLM-based query generator for enhanced query understanding
|
||||
:cvar custom: Custom query generator implementation
|
||||
"""
|
||||
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGSearchMode(StrEnum):
|
||||
"""
|
||||
Search modes for RAG query retrieval:
|
||||
- VECTOR: Uses vector similarity search for semantic matching
|
||||
- KEYWORD: Uses keyword-based search for exact matching
|
||||
- HYBRID: Combines both vector and keyword search for better results
|
||||
"""
|
||||
|
||||
VECTOR = "vector"
|
||||
KEYWORD = "keyword"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the default RAG query generator.
|
||||
|
||||
:param type: Type of query generator, always 'default'
|
||||
:param separator: String separator used to join query terms
|
||||
"""
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
separator: str = " "
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the LLM-based RAG query generator.
|
||||
|
||||
:param type: Type of query generator, always 'llm'
|
||||
:param model: Name of the language model to use for query generation
|
||||
:param template: Template string for formatting the query generation prompt
|
||||
"""
|
||||
|
||||
type: Literal["llm"] = "llm"
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = Annotated[
|
||||
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the RAG query generation.
|
||||
|
||||
:param query_generator_config: Configuration for the query generator.
|
||||
:param max_tokens_in_context: Maximum number of tokens in the context.
|
||||
:param max_chunks: Maximum number of chunks to retrieve.
|
||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
|
||||
"""
|
||||
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
|
||||
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||
|
||||
@field_validator("chunk_template")
|
||||
def validate_chunk_template(cls, v: str) -> str:
|
||||
if "{chunk.content}" not in v:
|
||||
raise ValueError("chunk_template must contain {chunk.content}")
|
||||
if "{index}" not in v:
|
||||
raise ValueError("chunk_template must contain {index}")
|
||||
if len(v) == 0:
|
||||
raise ValueError("chunk_template must not be empty")
|
||||
return v
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system.
|
||||
|
||||
:param documents: List of documents to index in the RAG system
|
||||
:param vector_db_id: ID of the vector database to store the document embeddings
|
||||
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
"""Query the RAG system for context; typically invoked by the agent.
|
||||
|
||||
:param content: The query content to search for in the indexed documents
|
||||
:param vector_db_ids: List of vector database IDs to search within
|
||||
:param query_config: (Optional) Configuration parameters for the query operation
|
||||
:returns: RAGQueryResult containing the retrieved content and metadata
|
||||
"""
|
||||
...
|
||||
221
src/llama_stack/apis/tools/tools.py
Normal file
221
src/llama_stack/apis/tools/tools.py
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Human-readable description of what the tool does
|
||||
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
|
||||
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
|
||||
"""
|
||||
|
||||
toolgroup_id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroupInput(BaseModel):
|
||||
"""Input data for registering a tool group.
|
||||
|
||||
:param toolgroup_id: Unique identifier for the tool group
|
||||
:param provider_id: ID of the provider that will handle this tool group
|
||||
:param args: (Optional) Additional arguments to pass to the provider
|
||||
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
|
||||
"""
|
||||
|
||||
toolgroup_id: str
|
||||
provider_id: str
|
||||
args: dict[str, Any] | None = None
|
||||
mcp_endpoint: URL | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroup(Resource):
|
||||
"""A group of related tools managed together.
|
||||
|
||||
:param type: Type of resource, always 'tool_group'
|
||||
:param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools
|
||||
:param args: (Optional) Additional arguments for the tool group
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
|
||||
mcp_endpoint: URL | None = None
|
||||
args: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolInvocationResult(BaseModel):
|
||||
"""Result of a tool invocation.
|
||||
|
||||
:param content: (Optional) The output content from the tool execution
|
||||
:param error_message: (Optional) Error message if the tool execution failed
|
||||
:param error_code: (Optional) Numeric error code if the tool execution failed
|
||||
:param metadata: (Optional) Additional metadata about the tool execution
|
||||
"""
|
||||
|
||||
content: InterleavedContent | None = None
|
||||
error_message: str | None = None
|
||||
error_code: int | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
async def get_tool(self, tool_name: str) -> ToolDef: ...
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
class ListToolGroupsResponse(BaseModel):
|
||||
"""Response containing a list of tool groups.
|
||||
|
||||
:param data: List of tool groups
|
||||
"""
|
||||
|
||||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
:param data: List of tool definitions
|
||||
"""
|
||||
|
||||
data: list[ToolDef]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Register a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to register.
|
||||
:param provider_id: The ID of the provider to use for the tool group.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:param args: A dictionary of arguments to pass to the tool group.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup:
|
||||
"""Get a tool group by its ID.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to get.
|
||||
:returns: A ToolGroup.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
"""List tool groups with optional provider.
|
||||
|
||||
:returns: A ListToolGroupsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> ToolDef:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A ToolDef.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> None:
|
||||
"""Unregister a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to unregister.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SpecialToolGroup(Enum):
|
||||
"""Special tool groups with predefined functionality.
|
||||
|
||||
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
|
||||
"""
|
||||
|
||||
rag_tool = "rag_tool"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
"""List all tools in the runtime.
|
||||
|
||||
:param tool_group_id: The ID of the tool group to list tools for.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments.
|
||||
|
||||
:param tool_name: The name of the tool to invoke.
|
||||
:param kwargs: A dictionary of arguments to pass to the tool.
|
||||
:returns: A ToolInvocationResult.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/vector_io/__init__.py
Normal file
7
src/llama_stack/apis/vector_io/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .vector_io import *
|
||||
960
src/llama_stack/apis/vector_io/vector_io.py
Normal file
960
src/llama_stack/apis/vector_io/vector_io.py
Normal file
|
|
@ -0,0 +1,960 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import uuid
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.strong_typing.schema import register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChunkMetadata(BaseModel):
|
||||
"""
|
||||
`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
|
||||
will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata`
|
||||
is set during chunk creation in `MemoryToolRuntimeImpl().insert()`and is not expected to change after.
|
||||
Use `Chunk.metadata` for metadata that will be used in the context during inference.
|
||||
:param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
|
||||
:param document_id: The ID of the document this chunk belongs to.
|
||||
:param source: The source of the content, such as a URL, file path, or other identifier.
|
||||
:param created_timestamp: An optional timestamp indicating when the chunk was created.
|
||||
:param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
|
||||
:param chunk_window: The window of the chunk, which can be used to group related chunks together.
|
||||
:param chunk_tokenizer: The tokenizer used to create the chunk. Default is Tiktoken.
|
||||
:param chunk_embedding_model: The embedding model used to create the chunk's embedding.
|
||||
:param chunk_embedding_dimension: The dimension of the embedding vector for the chunk.
|
||||
:param content_token_count: The number of tokens in the content of the chunk.
|
||||
:param metadata_token_count: The number of tokens in the metadata of the chunk.
|
||||
"""
|
||||
|
||||
chunk_id: str | None = None
|
||||
document_id: str | None = None
|
||||
source: str | None = None
|
||||
created_timestamp: int | None = None
|
||||
updated_timestamp: int | None = None
|
||||
chunk_window: str | None = None
|
||||
chunk_tokenizer: str | None = None
|
||||
chunk_embedding_model: str | None = None
|
||||
chunk_embedding_dimension: int | None = None
|
||||
content_token_count: int | None = None
|
||||
metadata_token_count: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Chunk(BaseModel):
|
||||
"""
|
||||
A chunk of content that can be inserted into a vector database.
|
||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
|
||||
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
|
||||
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
|
||||
The `chunk_metadata` is required backend functionality.
|
||||
"""
|
||||
|
||||
content: InterleavedContent
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
|
||||
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
|
||||
chunk_metadata: ChunkMetadata | None = None
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
def model_post_init(self, __context):
|
||||
# Extract chunk_id from metadata if present
|
||||
if self.metadata and "chunk_id" in self.metadata:
|
||||
self.stored_chunk_id = self.metadata.pop("chunk_id")
|
||||
|
||||
@property
|
||||
def chunk_id(self) -> str:
|
||||
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
|
||||
if self.stored_chunk_id:
|
||||
return self.stored_chunk_id
|
||||
|
||||
if "document_id" in self.metadata:
|
||||
return generate_chunk_id(self.metadata["document_id"], str(self.content))
|
||||
|
||||
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
|
||||
|
||||
@property
|
||||
def document_id(self) -> str | None:
|
||||
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
|
||||
# Check metadata first (takes precedence)
|
||||
doc_id = self.metadata.get("document_id")
|
||||
if doc_id is not None:
|
||||
if not isinstance(doc_id, str):
|
||||
raise TypeError(f"metadata['document_id'] must be a string, got {type(doc_id).__name__}: {doc_id!r}")
|
||||
return doc_id
|
||||
|
||||
# Fall back to chunk_metadata if available (Pydantic ensures type safety)
|
||||
if self.chunk_metadata is not None:
|
||||
return self.chunk_metadata.document_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryChunksResponse(BaseModel):
|
||||
"""Response from querying chunks in a vector database.
|
||||
|
||||
:param chunks: List of content chunks returned from the query
|
||||
:param scores: Relevance scores corresponding to each returned chunk
|
||||
"""
|
||||
|
||||
chunks: list[Chunk]
|
||||
scores: list[float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileCounts(BaseModel):
|
||||
"""File processing status counts for a vector store.
|
||||
|
||||
:param completed: Number of files that have been successfully processed
|
||||
:param cancelled: Number of files that had their processing cancelled
|
||||
:param failed: Number of files that failed to process
|
||||
:param in_progress: Number of files currently being processed
|
||||
:param total: Total number of files in the vector store
|
||||
"""
|
||||
|
||||
completed: int
|
||||
cancelled: int
|
||||
failed: int
|
||||
in_progress: int
|
||||
total: int
|
||||
|
||||
|
||||
# TODO: rename this as OpenAIVectorStore
|
||||
@json_schema_type
|
||||
class VectorStoreObject(BaseModel):
|
||||
"""OpenAI Vector Store object.
|
||||
|
||||
:param id: Unique identifier for the vector store
|
||||
:param object: Object type identifier, always "vector_store"
|
||||
:param created_at: Timestamp when the vector store was created
|
||||
:param name: (Optional) Name of the vector store
|
||||
:param usage_bytes: Storage space used by the vector store in bytes
|
||||
:param file_counts: File processing status counts for the vector store
|
||||
:param status: Current status of the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param expires_at: (Optional) Timestamp when the vector store will expire
|
||||
:param last_active_at: (Optional) Timestamp of last activity on the vector store
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store"
|
||||
created_at: int
|
||||
name: str | None = None
|
||||
usage_bytes: int = 0
|
||||
file_counts: VectorStoreFileCounts
|
||||
status: str = "completed"
|
||||
expires_after: dict[str, Any] | None = None
|
||||
expires_at: int | None = None
|
||||
last_active_at: int | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreCreateRequest(BaseModel):
|
||||
"""Request to create a vector store.
|
||||
|
||||
:param name: (Optional) Name for the vector store
|
||||
:param file_ids: List of file IDs to include in the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
file_ids: list[str] = Field(default_factory=list)
|
||||
expires_after: dict[str, Any] | None = None
|
||||
chunking_strategy: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreModifyRequest(BaseModel):
|
||||
"""Request to modify a vector store.
|
||||
|
||||
:param name: (Optional) Updated name for the vector store
|
||||
:param expires_after: (Optional) Updated expiration policy for the vector store
|
||||
:param metadata: (Optional) Updated set of key-value pairs for the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
expires_after: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreListResponse(BaseModel):
|
||||
"""Response from listing vector stores.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store objects
|
||||
:param first_id: (Optional) ID of the first vector store in the list for pagination
|
||||
:param last_id: (Optional) ID of the last vector store in the list for pagination
|
||||
:param has_more: Whether there are more vector stores available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreObject]
|
||||
first_id: str | None = None
|
||||
last_id: str | None = None
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchRequest(BaseModel):
|
||||
"""Request to search a vector store.
|
||||
|
||||
:param query: Search query as a string or list of strings
|
||||
:param filters: (Optional) Filters based on file attributes to narrow search results
|
||||
:param max_num_results: Maximum number of results to return, defaults to 10
|
||||
:param ranking_options: (Optional) Options for ranking and filtering search results
|
||||
:param rewrite_query: Whether to rewrite the query for better vector search performance
|
||||
"""
|
||||
|
||||
query: str | list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
max_num_results: int = 10
|
||||
ranking_options: dict[str, Any] | None = None
|
||||
rewrite_query: bool = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreContent(BaseModel):
|
||||
"""Content item from a vector store file or search result.
|
||||
|
||||
:param type: Content type, currently only "text" is supported
|
||||
:param text: The actual text content
|
||||
"""
|
||||
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchResponse(BaseModel):
|
||||
"""Response from searching a vector store.
|
||||
|
||||
:param file_id: Unique identifier of the file containing the result
|
||||
:param filename: Name of the file containing the result
|
||||
:param score: Relevance score for this search result
|
||||
:param attributes: (Optional) Key-value attributes associated with the file
|
||||
:param content: List of content items matching the search query
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
score: float
|
||||
attributes: dict[str, str | float | bool] | None = None
|
||||
content: list[VectorStoreContent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreSearchResponsePage(BaseModel):
|
||||
"""Paginated response from searching a vector store.
|
||||
|
||||
:param object: Object type identifier for the search results page
|
||||
:param search_query: The original search query that was executed
|
||||
:param data: List of search result objects
|
||||
:param has_more: Whether there are more results available beyond this page
|
||||
:param next_page: (Optional) Token for retrieving the next page of results
|
||||
"""
|
||||
|
||||
object: str = "vector_store.search_results.page"
|
||||
search_query: str
|
||||
data: list[VectorStoreSearchResponse]
|
||||
has_more: bool = False
|
||||
next_page: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreDeleteResponse(BaseModel):
|
||||
"""Response from deleting a vector store.
|
||||
|
||||
:param id: Unique identifier of the deleted vector store
|
||||
:param object: Object type identifier for the deletion response
|
||||
:param deleted: Whether the deletion operation was successful
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.deleted"
|
||||
deleted: bool = True
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyAuto(BaseModel):
|
||||
"""Automatic chunking strategy for vector store files.
|
||||
|
||||
:param type: Strategy type, always "auto" for automatic chunking
|
||||
"""
|
||||
|
||||
type: Literal["auto"] = "auto"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyStaticConfig(BaseModel):
|
||||
"""Configuration for static chunking strategy.
|
||||
|
||||
:param chunk_overlap_tokens: Number of tokens to overlap between adjacent chunks
|
||||
:param max_chunk_size_tokens: Maximum number of tokens per chunk, must be between 100 and 4096
|
||||
"""
|
||||
|
||||
chunk_overlap_tokens: int = 400
|
||||
max_chunk_size_tokens: int = Field(800, ge=100, le=4096)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyStatic(BaseModel):
|
||||
"""Static chunking strategy with configurable parameters.
|
||||
|
||||
:param type: Strategy type, always "static" for static chunking
|
||||
:param static: Configuration parameters for the static chunking strategy
|
||||
"""
|
||||
|
||||
type: Literal["static"] = "static"
|
||||
static: VectorStoreChunkingStrategyStaticConfig
|
||||
|
||||
|
||||
VectorStoreChunkingStrategy = Annotated[
|
||||
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
|
||||
|
||||
|
||||
class SearchRankingOptions(BaseModel):
|
||||
"""Options for ranking and filtering search results.
|
||||
|
||||
:param ranker: (Optional) Name of the ranking algorithm to use
|
||||
:param score_threshold: (Optional) Minimum relevance score threshold for results
|
||||
"""
|
||||
|
||||
ranker: str | None = None
|
||||
# NOTE: OpenAI File Search Tool requires threshold to be between 0 and 1, however
|
||||
# we don't guarantee that the score is between 0 and 1, so will leave this unconstrained
|
||||
# and let the provider handle it
|
||||
score_threshold: float | None = Field(default=0.0)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileLastError(BaseModel):
|
||||
"""Error information for failed vector store file processing.
|
||||
|
||||
:param code: Error code indicating the type of failure
|
||||
:param message: Human-readable error message describing the failure
|
||||
"""
|
||||
|
||||
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
|
||||
message: str
|
||||
|
||||
|
||||
VectorStoreFileStatus = Literal["completed"] | Literal["in_progress"] | Literal["cancelled"] | Literal["failed"]
|
||||
register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileObject(BaseModel):
|
||||
"""OpenAI Vector Store File object.
|
||||
|
||||
:param id: Unique identifier for the file
|
||||
:param object: Object type identifier, always "vector_store.file"
|
||||
:param attributes: Key-value attributes associated with the file
|
||||
:param chunking_strategy: Strategy used for splitting the file into chunks
|
||||
:param created_at: Timestamp when the file was added to the vector store
|
||||
:param last_error: (Optional) Error information if file processing failed
|
||||
:param status: Current processing status of the file
|
||||
:param usage_bytes: Storage space used by this file in bytes
|
||||
:param vector_store_id: ID of the vector store containing this file
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file"
|
||||
attributes: dict[str, Any] = Field(default_factory=dict)
|
||||
chunking_strategy: VectorStoreChunkingStrategy
|
||||
created_at: int
|
||||
last_error: VectorStoreFileLastError | None = None
|
||||
status: VectorStoreFileStatus
|
||||
usage_bytes: int = 0
|
||||
vector_store_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreListFilesResponse(BaseModel):
|
||||
"""Response from listing files in a vector store.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store file objects
|
||||
:param first_id: (Optional) ID of the first file in the list for pagination
|
||||
:param last_id: (Optional) ID of the last file in the list for pagination
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreFileObject]
|
||||
first_id: str | None = None
|
||||
last_id: str | None = None
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileContentsResponse(BaseModel):
|
||||
"""Response from retrieving the contents of a vector store file.
|
||||
|
||||
:param file_id: Unique identifier for the file
|
||||
:param filename: Name of the file
|
||||
:param attributes: Key-value attributes associated with the file
|
||||
:param content: List of content items from the file
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
attributes: dict[str, Any]
|
||||
content: list[VectorStoreContent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileDeleteResponse(BaseModel):
|
||||
"""Response from deleting a vector store file.
|
||||
|
||||
:param id: Unique identifier of the deleted file
|
||||
:param object: Object type identifier for the deletion response
|
||||
:param deleted: Whether the deletion operation was successful
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file.deleted"
|
||||
deleted: bool = True
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileBatchObject(BaseModel):
|
||||
"""OpenAI Vector Store File Batch object.
|
||||
|
||||
:param id: Unique identifier for the file batch
|
||||
:param object: Object type identifier, always "vector_store.file_batch"
|
||||
:param created_at: Timestamp when the file batch was created
|
||||
:param vector_store_id: ID of the vector store containing the file batch
|
||||
:param status: Current processing status of the file batch
|
||||
:param file_counts: File processing status counts for the batch
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: str = "vector_store.file_batch"
|
||||
created_at: int
|
||||
vector_store_id: str
|
||||
status: VectorStoreFileStatus
|
||||
file_counts: VectorStoreFileCounts
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFilesListInBatchResponse(BaseModel):
|
||||
"""Response from listing files in a vector store file batch.
|
||||
|
||||
:param object: Object type identifier, always "list"
|
||||
:param data: List of vector store file objects in the batch
|
||||
:param first_id: (Optional) ID of the first file in the list for pagination
|
||||
:param last_id: (Optional) ID of the last file in the list for pagination
|
||||
:param has_more: Whether there are more files available beyond this page
|
||||
"""
|
||||
|
||||
object: str = "list"
|
||||
data: list[VectorStoreFileObject]
|
||||
first_id: str | None = None
|
||||
last_id: str | None = None
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request to create a vector store with extra_body support.
|
||||
|
||||
:param name: (Optional) A name for the vector store
|
||||
:param file_ids: List of file IDs to include in the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
file_ids: list[str] | None = None
|
||||
expires_after: dict[str, Any] | None = None
|
||||
chunking_strategy: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request to create a vector store file batch with extra_body support.
|
||||
|
||||
:param file_ids: A list of File IDs that the vector store should use
|
||||
:param attributes: (Optional) Key-value attributes to store with the files
|
||||
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto
|
||||
"""
|
||||
|
||||
file_ids: list[str]
|
||||
attributes: dict[str, Any] | None = None
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None
|
||||
|
||||
|
||||
class VectorStoreTable(Protocol):
|
||||
def get_vector_store(self, vector_store_id: str) -> VectorStore | None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorIO(Protocol):
|
||||
vector_store_table: VectorStoreTable | None = None
|
||||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
# TODO: rename vector_db_id to vector_store_id once Stainless is working
|
||||
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Insert chunks into a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
|
||||
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
|
||||
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
|
||||
If `embedding` is not provided, it will be computed later.
|
||||
:param ttl_seconds: The time to live of the chunks.
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: rename vector_db_id to vector_store_id once Stainless is working
|
||||
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to query.
|
||||
:param query: The query to search for.
|
||||
:param params: The parameters of the query.
|
||||
:returns: A QueryChunksResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreObject:
|
||||
"""Creates a vector store.
|
||||
|
||||
Generate an OpenAI-compatible vector store with the given parameters.
|
||||
:returns: A VectorStoreObject representing the created vector store.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
"""Returns a list of vector stores.
|
||||
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:returns: A VectorStoreListResponse containing the list of vector stores.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
"""Retrieves a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to retrieve.
|
||||
:returns: A VectorStoreObject representing the vector store.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
"""Updates a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to update.
|
||||
:param name: The name of the vector store.
|
||||
:param expires_after: The expiration policy for a vector store.
|
||||
:param metadata: Set of 16 key-value pairs that can be attached to an object.
|
||||
:returns: A VectorStoreObject representing the updated vector store.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
"""Delete a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to delete.
|
||||
:returns: A VectorStoreDeleteResponse indicating the deletion status.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: (
|
||||
str | None
|
||||
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
"""Search for chunks in a vector store.
|
||||
|
||||
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to search.
|
||||
:param query: The query string or array for performing the search.
|
||||
:param filters: Filters based on file attributes to narrow the search results.
|
||||
:param max_num_results: Maximum number of results to return (1 to 50 inclusive, default 10).
|
||||
:param ranking_options: Ranking options for fine-tuning the search results.
|
||||
:param rewrite_query: Whether to rewrite the natural language query for vector search (default false)
|
||||
:param search_mode: The search mode to use - "keyword", "vector", or "hybrid" (default "vector")
|
||||
:returns: A VectorStoreSearchResponse containing the search results.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
"""Attach a file to a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to attach the file to.
|
||||
:param file_id: The ID of the file to attach to the vector store.
|
||||
:param attributes: The key-value attributes stored with the file, which can be used for filtering.
|
||||
:param chunking_strategy: The chunking strategy to use for the file.
|
||||
:returns: A VectorStoreFileObject representing the attached file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> VectorStoreListFilesResponse:
|
||||
"""List files in a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to list files from.
|
||||
:param limit: (Optional) A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: (Optional) Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param after: (Optional) A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: (Optional) A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:param filter: (Optional) Filter by file status to only return files with the specified status.
|
||||
:returns: A VectorStoreListFilesResponse containing the list of files.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
"""Retrieves a vector store file.
|
||||
|
||||
:param vector_store_id: The ID of the vector store containing the file to retrieve.
|
||||
:param file_id: The ID of the file to retrieve.
|
||||
:returns: A VectorStoreFileObject representing the file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
"""Retrieves the contents of a vector store file.
|
||||
|
||||
:param vector_store_id: The ID of the vector store containing the file to retrieve.
|
||||
:param file_id: The ID of the file to retrieve.
|
||||
:returns: A list of InterleavedContent representing the file contents.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
"""Updates a vector store file.
|
||||
|
||||
:param vector_store_id: The ID of the vector store containing the file to update.
|
||||
:param file_id: The ID of the file to update.
|
||||
:param attributes: The updated key-value attributes to store with the file.
|
||||
:returns: A VectorStoreFileObject representing the updated file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
"""Delete a vector store file.
|
||||
|
||||
:param vector_store_id: The ID of the vector store containing the file to delete.
|
||||
:param file_id: The ID of the file to delete.
|
||||
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Create a vector store file batch.
|
||||
|
||||
Generate an OpenAI-compatible vector store file batch for the given vector store.
|
||||
:param vector_store_id: The ID of the vector store to create the file batch for.
|
||||
:returns: A VectorStoreFileBatchObject representing the created file batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Retrieve a vector store file batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to retrieve.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:returns: A VectorStoreFileBatchObject representing the file batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
) -> VectorStoreFilesListInBatchResponse:
|
||||
"""Returns a list of vector store files in a batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to list files from.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:param filter: Filter by file status. One of in_progress, completed, failed, cancelled.
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:returns: A VectorStoreFilesListInBatchResponse containing the list of files in the batch.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Cancels a vector store file batch.
|
||||
|
||||
:param batch_id: The ID of the file batch to cancel.
|
||||
:param vector_store_id: The ID of the vector store containing the file batch.
|
||||
:returns: A VectorStoreFileBatchObject representing the cancelled file batch.
|
||||
"""
|
||||
...
|
||||
7
src/llama_stack/apis/vector_stores/__init__.py
Normal file
7
src/llama_stack/apis/vector_stores/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .vector_stores import *
|
||||
51
src/llama_stack/apis/vector_stores/vector_stores.py
Normal file
51
src/llama_stack/apis/vector_stores/vector_stores.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
# Internal resource type for storing the vector store routing and other information
|
||||
class VectorStore(Resource):
|
||||
"""Vector database resource for storing and querying vector embeddings.
|
||||
|
||||
:param type: Type of resource, always 'vector_store' for vector stores
|
||||
:param embedding_model: Name of the embedding model to use for vector generation
|
||||
:param embedding_dimension: Dimension of the embedding vectors
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.vector_store] = ResourceType.vector_store
|
||||
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
vector_store_name: str | None = None
|
||||
|
||||
@property
|
||||
def vector_store_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_vector_store_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class VectorStoreInput(BaseModel):
|
||||
"""Input parameters for creating or configuring a vector database.
|
||||
|
||||
:param vector_store_id: Unique identifier for the vector store
|
||||
:param embedding_model: Name of the embedding model to use for vector generation
|
||||
:param embedding_dimension: Dimension of the embedding vectors
|
||||
:param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
|
||||
"""
|
||||
|
||||
vector_store_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
provider_id: str | None = None
|
||||
provider_vector_store_id: str | None = None
|
||||
9
src/llama_stack/apis/version.py
Normal file
9
src/llama_stack/apis/version.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_API_V1 = "v1"
|
||||
LLAMA_STACK_API_V1BETA = "v1beta"
|
||||
LLAMA_STACK_API_V1ALPHA = "v1alpha"
|
||||
5
src/llama_stack/cli/__init__.py
Normal file
5
src/llama_stack/cli/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
56
src/llama_stack/cli/llama.py
Normal file
56
src/llama_stack/cli/llama.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.log import setup_logging
|
||||
|
||||
from .stack import StackParser
|
||||
from .stack.utils import print_subcommand_description
|
||||
|
||||
|
||||
class LlamaCLIParser:
|
||||
"""Defines CLI parser for Llama CLI"""
|
||||
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser(
|
||||
prog="llama",
|
||||
description="Welcome to the Llama CLI",
|
||||
add_help=True,
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
# Default command is to print help
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
StackParser.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
args = self.parser.parse_args()
|
||||
if not isinstance(args, argparse.Namespace):
|
||||
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
|
||||
return args
|
||||
|
||||
def run(self, args: argparse.Namespace) -> None:
|
||||
args.func(args)
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize logging from environment variables before any other operations
|
||||
setup_logging()
|
||||
|
||||
parser = LlamaCLIParser()
|
||||
args = parser.parse_args()
|
||||
parser.run(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
src/llama_stack/cli/scripts/__init__.py
Normal file
5
src/llama_stack/cli/scripts/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
38
src/llama_stack/cli/scripts/install-wheel-from-presigned.sh
Executable file
38
src/llama_stack/cli/scripts/install-wheel-from-presigned.sh
Executable file
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
echo "Please provide a URL as an argument."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
URL=$1
|
||||
|
||||
HEADERS_FILE=$(mktemp)
|
||||
curl -s -I "$URL" >"$HEADERS_FILE"
|
||||
FILENAME=$(grep -i "x-manifold-obj-canonicalpath:" "$HEADERS_FILE" | sed -E 's/.*nodes\/[^\/]+\/(.+)/\1/' | tr -d "\r\n")
|
||||
|
||||
if [ -z "$FILENAME" ]; then
|
||||
echo "Could not find the x-manifold-obj-canonicalpath header."
|
||||
echo "HEADERS_FILE contents: "
|
||||
cat "$HEADERS_FILE"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Downloading $FILENAME..."
|
||||
|
||||
curl -s -L -o "$FILENAME" "$URL"
|
||||
|
||||
echo "Installing $FILENAME..."
|
||||
pip install "$FILENAME"
|
||||
echo "Successfully installed $FILENAME"
|
||||
|
||||
rm -f "$FILENAME"
|
||||
18
src/llama_stack/cli/scripts/run.py
Normal file
18
src/llama_stack/cli/scripts/run.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def install_wheel_from_presigned():
|
||||
file = "install-wheel-from-presigned.sh"
|
||||
script_path = os.path.join(os.path.dirname(__file__), file)
|
||||
try:
|
||||
subprocess.run(["sh", script_path] + sys.argv[1:], check=True)
|
||||
except Exception:
|
||||
sys.exit(1)
|
||||
7
src/llama_stack/cli/stack/__init__.py
Normal file
7
src/llama_stack/cli/stack/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .stack import StackParser # noqa
|
||||
182
src/llama_stack/cli/stack/_list_deps.py
Normal file
182
src/llama_stack/cli/stack/_list_deps.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.core.build import get_provider_dependencies
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DistributionSpec,
|
||||
)
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
"aiosqlite",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"uvicorn",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
]
|
||||
|
||||
|
||||
def format_output_deps_only(
|
||||
normal_deps: list[str],
|
||||
special_deps: list[str],
|
||||
external_deps: list[str],
|
||||
uv: bool = False,
|
||||
) -> str:
|
||||
"""Format dependencies as a list."""
|
||||
lines = []
|
||||
|
||||
uv_str = ""
|
||||
if uv:
|
||||
uv_str = "uv pip install "
|
||||
|
||||
# Quote deps with commas
|
||||
quoted_normal_deps = [quote_if_needed(dep) for dep in normal_deps]
|
||||
lines.append(f"{uv_str}{' '.join(quoted_normal_deps)}")
|
||||
|
||||
for special_dep in special_deps:
|
||||
lines.append(f"{uv_str}{quote_special_dep(special_dep)}")
|
||||
|
||||
for external_dep in external_deps:
|
||||
lines.append(f"{uv_str}{quote_special_dep(external_dep)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def run_stack_list_deps_command(args: argparse.Namespace) -> None:
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
||||
config_file = resolve_config_or_distro(args.config, Mode.BUILD)
|
||||
except ValueError as e:
|
||||
cprint(
|
||||
f"Could not parse config file {args.config}: {e}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if config_file:
|
||||
with open(config_file) as f:
|
||||
try:
|
||||
contents = yaml.safe_load(f)
|
||||
contents = replace_env_vars(contents)
|
||||
build_config = BuildConfig(**contents)
|
||||
build_config.image_type = "venv"
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Could not parse config file {config_file}: {e}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
provider_list: dict[str, list[BuildProvider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider_type in providers_for_api:
|
||||
provider = BuildProvider(
|
||||
provider_type=provider_type,
|
||||
module=None,
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider_type} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=provider_list,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
build_config = BuildConfig(image_type=ImageType.VENV.value, distribution_spec=distribution_spec)
|
||||
|
||||
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
# Add external API dependencies
|
||||
if build_config.external_apis_dir:
|
||||
from llama_stack.core.external import load_external_apis
|
||||
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
# Format and output based on requested format
|
||||
output = format_output_deps_only(
|
||||
normal_deps=normal_deps,
|
||||
special_deps=special_deps,
|
||||
external_deps=external_provider_dependencies,
|
||||
uv=args.format == "uv",
|
||||
)
|
||||
|
||||
print(output)
|
||||
|
||||
|
||||
def quote_if_needed(dep):
|
||||
# Add quotes if the dependency contains special characters that need escaping in shell
|
||||
# This includes: commas, comparison operators (<, >, <=, >=, ==, !=)
|
||||
needs_quoting = any(char in dep for char in [",", "<", ">", "="])
|
||||
return f"'{dep}'" if needs_quoting else dep
|
||||
|
||||
|
||||
def quote_special_dep(dep_string):
|
||||
"""
|
||||
Quote individual packages in a special dependency string.
|
||||
Special deps may contain multiple packages and flags like --extra-index-url.
|
||||
We need to quote only the package specs that contain special characters.
|
||||
"""
|
||||
parts = dep_string.split()
|
||||
quoted_parts = []
|
||||
|
||||
for part in parts:
|
||||
# Don't quote flags (they start with -)
|
||||
if part.startswith("-"):
|
||||
quoted_parts.append(part)
|
||||
else:
|
||||
# Quote package specs that need it
|
||||
quoted_parts.append(quote_if_needed(part))
|
||||
|
||||
return " ".join(quoted_parts)
|
||||
47
src/llama_stack/cli/stack/list_apis.py
Normal file
47
src/llama_stack/cli/stack/list_apis.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class StackListApis(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list-apis",
|
||||
prog="llama stack list-apis",
|
||||
description="List APIs part of the Llama Stack implementation",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_apis_list_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
pass
|
||||
|
||||
def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.core.distribution import stack_apis
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"API",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for api in stack_apis():
|
||||
rows.append(
|
||||
[
|
||||
api.value,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
51
src/llama_stack/cli/stack/list_deps.py
Normal file
51
src/llama_stack/cli/stack/list_deps.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class StackListDeps(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list-deps",
|
||||
prog="llama stack list-deps",
|
||||
description="list the dependencies for a llama stack distribution",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_list_deps_command)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"config",
|
||||
type=str,
|
||||
nargs="?", # Make it optional
|
||||
metavar="config | distro",
|
||||
help="Path to config file to use or name of known distro (llama stack list for a list).",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="sync dependencies for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
choices=["uv", "deps-only"],
|
||||
default="deps-only",
|
||||
help="Output format: 'uv' shows shell commands, 'deps-only' shows just the list of dependencies without `uv` (default)",
|
||||
)
|
||||
|
||||
def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
# can be fast to load and reduces dependencies
|
||||
from ._list_deps import run_stack_list_deps_command
|
||||
|
||||
return run_stack_list_deps_command(args)
|
||||
76
src/llama_stack/cli/stack/list_providers.py
Normal file
76
src/llama_stack/cli/stack/list_providers.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class StackListProviders(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list-providers",
|
||||
prog="llama stack list-providers",
|
||||
description="Show available Llama Stack Providers for an API",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_providers_list_cmd)
|
||||
|
||||
@property
|
||||
def providable_apis(self):
|
||||
from llama_stack.core.distribution import providable_apis
|
||||
|
||||
return [api.value for api in providable_apis()]
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"api",
|
||||
type=str,
|
||||
choices=self.providable_apis,
|
||||
nargs="?",
|
||||
help="API to list providers for. List all if not specified.",
|
||||
)
|
||||
|
||||
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.core.distribution import Api, get_provider_registry
|
||||
|
||||
all_providers = get_provider_registry()
|
||||
if args.api:
|
||||
providers = [(args.api, all_providers[Api(args.api)])]
|
||||
else:
|
||||
providers = [(k.value, prov) for k, prov in all_providers.items()]
|
||||
|
||||
providers = [(api, p) for api, p in providers if api in self.providable_apis]
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"API Type",
|
||||
"Provider Type",
|
||||
"PIP Package Dependencies",
|
||||
]
|
||||
|
||||
rows = []
|
||||
|
||||
specs = [spec for api, p in providers for spec in p.values()]
|
||||
for spec in specs:
|
||||
if spec.is_sample:
|
||||
continue
|
||||
rows.append(
|
||||
[
|
||||
spec.api.value,
|
||||
spec.provider_type,
|
||||
",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
sort_by=(0, 1),
|
||||
)
|
||||
56
src/llama_stack/cli/stack/list_stacks.py
Normal file
56
src/llama_stack/cli/stack/list_stacks.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
|
||||
class StackListBuilds(Subcommand):
|
||||
"""List built stacks in .llama/distributions directory"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list",
|
||||
prog="llama stack list",
|
||||
description="list the build stacks",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._list_stack_command)
|
||||
|
||||
def _get_distribution_dirs(self) -> dict[str, Path]:
|
||||
"""Return a dictionary of distribution names and their paths"""
|
||||
distributions = {}
|
||||
dist_dir = Path.home() / ".llama" / "distributions"
|
||||
|
||||
if dist_dir.exists():
|
||||
for stack_dir in dist_dir.iterdir():
|
||||
if stack_dir.is_dir():
|
||||
distributions[stack_dir.name] = stack_dir
|
||||
return distributions
|
||||
|
||||
def _list_stack_command(self, args: argparse.Namespace) -> None:
|
||||
distributions = self._get_distribution_dirs()
|
||||
|
||||
if not distributions:
|
||||
print("No stacks found in ~/.llama/distributions")
|
||||
return
|
||||
|
||||
headers = ["Stack Name", "Path"]
|
||||
headers.extend(["Build Config", "Run Config"])
|
||||
rows = []
|
||||
for name, path in distributions.items():
|
||||
row = [name, str(path)]
|
||||
# Check for build and run config files
|
||||
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
|
||||
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
|
||||
row.extend([build_config, run_config])
|
||||
rows.append(row)
|
||||
print_table(rows, headers, separate_rows=True)
|
||||
115
src/llama_stack/cli/stack/remove.py
Normal file
115
src/llama_stack/cli/stack/remove.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
|
||||
class StackRemove(Subcommand):
|
||||
"""Remove the build stack"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"rm",
|
||||
prog="llama stack rm",
|
||||
description="Remove the build stack",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._remove_stack_build_command)
|
||||
|
||||
def _add_arguments(self) -> None:
|
||||
self.parser.add_argument(
|
||||
"name",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Name of the stack to delete",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--all",
|
||||
"-a",
|
||||
action="store_true",
|
||||
help="Delete all stacks (use with caution)",
|
||||
)
|
||||
|
||||
def _get_distribution_dirs(self) -> dict[str, Path]:
|
||||
"""Return a dictionary of distribution names and their paths"""
|
||||
distributions = {}
|
||||
dist_dir = Path.home() / ".llama" / "distributions"
|
||||
|
||||
if dist_dir.exists():
|
||||
for stack_dir in dist_dir.iterdir():
|
||||
if stack_dir.is_dir():
|
||||
distributions[stack_dir.name] = stack_dir
|
||||
return distributions
|
||||
|
||||
def _list_stacks(self) -> None:
|
||||
"""Display available stacks in a table"""
|
||||
distributions = self._get_distribution_dirs()
|
||||
if not distributions:
|
||||
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
headers = ["Stack Name", "Path"]
|
||||
rows = [[name, str(path)] for name, path in distributions.items()]
|
||||
print_table(rows, headers, separate_rows=True)
|
||||
|
||||
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
distributions = self._get_distribution_dirs()
|
||||
|
||||
if args.all:
|
||||
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
|
||||
if confirm != "yes-i-really-want":
|
||||
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
||||
return
|
||||
|
||||
for name, path in distributions.items():
|
||||
try:
|
||||
shutil.rmtree(path)
|
||||
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Failed to delete stack {name}: {e}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not args.name:
|
||||
self._list_stacks()
|
||||
if not args.name:
|
||||
return
|
||||
|
||||
if args.name not in distributions:
|
||||
self._list_stacks()
|
||||
cprint(
|
||||
f"Stack not found: {args.name}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
stack_path = distributions[args.name]
|
||||
|
||||
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
|
||||
if confirm != "y":
|
||||
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
||||
return
|
||||
|
||||
try:
|
||||
shutil.rmtree(stack_path)
|
||||
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
|
||||
except Exception as e:
|
||||
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
214
src/llama_stack/cli/stack/run.py
Normal file
214
src/llama_stack/cli/stack/run.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import LoggingConfig, get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"run",
|
||||
prog="llama stack run",
|
||||
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_run_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"config",
|
||||
type=str,
|
||||
nargs="?", # Make it optional
|
||||
metavar="config | distro",
|
||||
help="Path to config file to use for the run or name of known distro (`llama stack list` for a list).",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--enable-ui",
|
||||
action="store_true",
|
||||
help="Start the UI server",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
|
||||
if args.image_type or args.image_name:
|
||||
self.parser.error(
|
||||
"The --image-type and --image-name flags are no longer supported.\n\n"
|
||||
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
|
||||
"For example:\n"
|
||||
" source /path/to/venv/bin/activate\n"
|
||||
" llama stack run <config>\n"
|
||||
)
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
||||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
try:
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
except yaml.parser.ParserError as e:
|
||||
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
|
||||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
||||
self._uvicorn_run(config_file, args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||
if not config_file:
|
||||
self.parser.error("Config file is required")
|
||||
|
||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
||||
uvicorn_config = {
|
||||
"factory": True,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
if keyfile and certfile:
|
||||
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
|
||||
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
|
||||
if config.server.tls_cafile:
|
||||
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
logger.info(f"Listening on {host}:{port}")
|
||||
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
# Check if npm is available
|
||||
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
|
||||
if npm_check.returncode != 0:
|
||||
logger.warning(
|
||||
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
|
||||
)
|
||||
return
|
||||
|
||||
ui_dir = REPO_ROOT / "llama_stack" / "ui"
|
||||
logs_dir = Path("~/.llama/ui/logs").expanduser()
|
||||
try:
|
||||
# Create logs directory if it doesn't exist
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ui_stdout_log_path = logs_dir / "stdout.log"
|
||||
ui_stderr_log_path = logs_dir / "stderr.log"
|
||||
|
||||
# Open log files in append mode
|
||||
stdout_log_file = open(ui_stdout_log_path, "a")
|
||||
stderr_log_file = open(ui_stderr_log_path, "a")
|
||||
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=str(ui_dir),
|
||||
stdout=stdout_log_file,
|
||||
stderr=stderr_log_file,
|
||||
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
|
||||
)
|
||||
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
|
||||
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
|
||||
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
48
src/llama_stack/cli/stack/stack.py
Normal file
48
src/llama_stack/cli/stack/stack.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from importlib.metadata import version
|
||||
|
||||
from llama_stack.cli.stack.list_stacks import StackListBuilds
|
||||
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
from .list_apis import StackListApis
|
||||
from .list_deps import StackListDeps
|
||||
from .list_providers import StackListProviders
|
||||
from .remove import StackRemove
|
||||
from .run import StackRun
|
||||
|
||||
|
||||
class StackParser(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"stack",
|
||||
prog="llama stack",
|
||||
description="Operations for the Llama Stack / Distributions",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version=f"{version('llama-stack')}",
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="stack_subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
StackListDeps.create(subparsers)
|
||||
StackListApis.create(subparsers)
|
||||
StackListProviders.create(subparsers)
|
||||
StackRun.create(subparsers)
|
||||
StackRemove.create(subparsers)
|
||||
StackListBuilds.create(subparsers)
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
151
src/llama_stack/cli/stack/utils.py
Normal file
151
src/llama_stack/cli/stack/utils.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import sys
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildConfig,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
||||
class ImageType(Enum):
|
||||
CONTAINER = "container"
|
||||
VENV = "venv"
|
||||
|
||||
|
||||
def print_subcommand_description(parser, subparsers):
|
||||
"""Print descriptions of subcommands."""
|
||||
description_text = ""
|
||||
for name, subcommand in subparsers.choices.items():
|
||||
description = subcommand.description
|
||||
description_text += f" {name:<21} {description}\n"
|
||||
parser.epilog = description_text
|
||||
|
||||
|
||||
def generate_run_config(
|
||||
build_config: BuildConfig,
|
||||
build_dir: Path,
|
||||
image_name: str,
|
||||
) -> Path:
|
||||
"""
|
||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
distro_dir = DISTRIBS_BASE_DIR / image_name
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
storage=StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=str(distro_dir / "kvstore.db")),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=str(distro_dir / "sql_store.db")),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
|
||||
),
|
||||
),
|
||||
external_providers_dir=build_config.external_providers_dir
|
||||
if build_config.external_providers_dir
|
||||
else EXTERNAL_PROVIDERS_DIR,
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry(build_config)
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
providers = build_config.distribution_spec.providers[api]
|
||||
|
||||
for provider in providers:
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
|
||||
p = provider_registry[Api(api)][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
try:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
|
||||
except (ModuleNotFoundError, ValueError) as exc:
|
||||
# HACK ALERT:
|
||||
# This code executes after building is done, the import cannot work since the
|
||||
# package is either available in the venv or container - not available on the host.
|
||||
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
|
||||
# external
|
||||
cprint(
|
||||
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Set config_type to None to avoid UnboundLocalError
|
||||
config_type = None
|
||||
|
||||
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
|
||||
else:
|
||||
config = {}
|
||||
|
||||
p_spec = Provider(
|
||||
provider_id=pid,
|
||||
provider_type=provider.provider_type,
|
||||
config=config,
|
||||
module=provider.module,
|
||||
)
|
||||
run_config.providers[api].append(p_spec)
|
||||
|
||||
run_config_file = build_dir / f"{image_name}-run.yaml"
|
||||
|
||||
with open(run_config_file, "w") as f:
|
||||
to_write = json.loads(run_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
# Only print this message for non-container builds since it will be displayed before the
|
||||
# container is built
|
||||
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
||||
# makes sense to display this message
|
||||
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
||||
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
|
||||
return run_config_file
|
||||
|
||||
|
||||
@lru_cache
|
||||
def available_templates_specs() -> dict[str, BuildConfig]:
|
||||
import yaml
|
||||
|
||||
template_specs = {}
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
template_name = p.parent.name
|
||||
with open(p) as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs[template_name] = build_config
|
||||
return template_specs
|
||||
19
src/llama_stack/cli/subcommand.py
Normal file
19
src/llama_stack/cli/subcommand.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
class Subcommand:
|
||||
"""All llama cli subcommands must inherit from this class"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
def _add_arguments(self):
|
||||
pass
|
||||
39
src/llama_stack/cli/table.py
Normal file
39
src/llama_stack/cli/table.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
|
||||
# Convert rows and handle None values
|
||||
rows = [[x or "" for x in row] for row in rows]
|
||||
|
||||
# Sort rows if sort_by is specified
|
||||
if sort_by:
|
||||
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
|
||||
|
||||
# Create Rich table
|
||||
table = Table(show_lines=separate_rows)
|
||||
|
||||
# Add headers if provided
|
||||
if headers:
|
||||
for header in headers:
|
||||
table.add_column(header, style="bold white")
|
||||
else:
|
||||
# Add unnamed columns based on first row
|
||||
for _ in range(len(rows[0]) if rows else 0):
|
||||
table.add_column()
|
||||
|
||||
# Add rows
|
||||
for row in rows:
|
||||
table.add_row(*row)
|
||||
|
||||
# Print table
|
||||
console = Console()
|
||||
console.print(table)
|
||||
29
src/llama_stack/cli/utils.py
Normal file
29
src/llama_stack/cli/utils.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
# TODO: this can probably just be inlined now?
|
||||
def add_config_distro_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/distro arguments."""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
group.add_argument(
|
||||
"config",
|
||||
nargs="?",
|
||||
help="Configuration file path or distribution name",
|
||||
)
|
||||
|
||||
|
||||
def get_config_from_args(args: argparse.Namespace) -> str | None:
|
||||
if args.config is not None:
|
||||
return str(args.config)
|
||||
return None
|
||||
5
src/llama_stack/core/__init__.py
Normal file
5
src/llama_stack/core/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
5
src/llama_stack/core/access_control/__init__.py
Normal file
5
src/llama_stack/core/access_control/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
131
src/llama_stack/core/access_control/access_control.py
Normal file
131
src/llama_stack/core/access_control/access_control.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
|
||||
from .conditions import (
|
||||
Condition,
|
||||
ProtectedResource,
|
||||
parse_conditions,
|
||||
)
|
||||
from .datatypes import (
|
||||
AccessRule,
|
||||
Action,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
||||
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
|
||||
if resource_scope == actual_resource:
|
||||
return True
|
||||
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
|
||||
|
||||
|
||||
def matches_scope(
|
||||
scope: Scope,
|
||||
action: Action,
|
||||
resource: str,
|
||||
user: str | None,
|
||||
) -> bool:
|
||||
if scope.resource and not matches_resource(scope.resource, resource):
|
||||
return False
|
||||
if scope.principal and scope.principal != user:
|
||||
return False
|
||||
return action in scope.actions
|
||||
|
||||
|
||||
def as_list(obj: Any) -> list[Any]:
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
return [obj]
|
||||
|
||||
|
||||
def matches_conditions(
|
||||
conditions: list[Condition],
|
||||
resource: ProtectedResource,
|
||||
user: User,
|
||||
) -> bool:
|
||||
for condition in conditions:
|
||||
# must match all conditions
|
||||
if not condition.matches(resource, user):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def default_policy() -> list[AccessRule]:
|
||||
# for backwards compatibility, if no rules are provided, assume
|
||||
# full access subject to previous attribute matching rules
|
||||
return [
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def is_action_allowed(
|
||||
policy: list[AccessRule],
|
||||
action: Action,
|
||||
resource: ProtectedResource,
|
||||
user: User | None,
|
||||
) -> bool:
|
||||
# If user is not set, assume authentication is not enabled
|
||||
if not user:
|
||||
return True
|
||||
|
||||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
qualified_resource_id = f"{resource.type}::{resource.identifier}"
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return False
|
||||
elif rule.unless:
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return True
|
||||
elif rule.unless:
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
# assume access is denied unless we find a rule that permits access
|
||||
return False
|
||||
|
||||
|
||||
class AccessDeniedError(RuntimeError):
|
||||
def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
|
||||
self.action = action
|
||||
self.resource = resource
|
||||
self.user = user
|
||||
|
||||
message = _build_access_denied_message(action, resource, user)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str:
|
||||
"""Build detailed error message for access denied scenarios."""
|
||||
if action and resource and user:
|
||||
resource_info = f"{resource.type}::{resource.identifier}"
|
||||
user_info = f"'{user.principal}'"
|
||||
if user.attributes:
|
||||
attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
|
||||
user_info += f" (attributes: {attrs})"
|
||||
|
||||
message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
|
||||
else:
|
||||
message = "Insufficient permissions"
|
||||
|
||||
return message
|
||||
129
src/llama_stack/core/access_control/conditions.py
Normal file
129
src/llama_stack/core/access_control/conditions.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class User(Protocol):
|
||||
principal: str
|
||||
attributes: dict[str, list[str]] | None
|
||||
|
||||
|
||||
class ProtectedResource(Protocol):
|
||||
type: str
|
||||
identifier: str
|
||||
owner: User
|
||||
|
||||
|
||||
class Condition(Protocol):
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||
|
||||
|
||||
class UserInOwnersList:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||
if (
|
||||
hasattr(resource, "owner")
|
||||
and resource.owner
|
||||
and resource.owner.attributes
|
||||
and self.name in resource.owner.attributes
|
||||
):
|
||||
return resource.owner.attributes[self.name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
required = self.owners_values(resource)
|
||||
if not required:
|
||||
return True
|
||||
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||
return False
|
||||
user_values = user.attributes[self.name]
|
||||
for value in required:
|
||||
if value in user_values:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user in owners {self.name}"
|
||||
|
||||
|
||||
class UserNotInOwnersList(UserInOwnersList):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user not in owners {self.name}"
|
||||
|
||||
|
||||
class UserWithValueInList:
|
||||
def __init__(self, name: str, value: str):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
if user.attributes and self.name in user.attributes:
|
||||
return self.value in user.attributes[self.name]
|
||||
print(f"User does not have {self.value} in {self.name}")
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} in {self.name}"
|
||||
|
||||
|
||||
class UserWithValueNotInList(UserWithValueInList):
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__(name, value)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} not in {self.name}"
|
||||
|
||||
|
||||
class UserIsOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return resource.owner.principal == user.principal if resource.owner else False
|
||||
|
||||
def __repr__(self):
|
||||
return "user is owner"
|
||||
|
||||
|
||||
class UserIsNotOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not resource.owner or resource.owner.principal != user.principal
|
||||
|
||||
def __repr__(self):
|
||||
return "user is not owner"
|
||||
|
||||
|
||||
def parse_condition(condition: str) -> Condition:
|
||||
words = condition.split()
|
||||
match words:
|
||||
case ["user", "is", "owner"]:
|
||||
return UserIsOwner()
|
||||
case ["user", "is", "not", "owner"]:
|
||||
return UserIsNotOwner()
|
||||
case ["user", "with", value, "in", name]:
|
||||
return UserWithValueInList(name, value)
|
||||
case ["user", "with", value, "not", "in", name]:
|
||||
return UserWithValueNotInList(name, value)
|
||||
case ["user", "in", "owners", name]:
|
||||
return UserInOwnersList(name)
|
||||
case ["user", "not", "in", "owners", name]:
|
||||
return UserNotInOwnersList(name)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||
return [parse_condition(c) for c in conditions]
|
||||
107
src/llama_stack/core/access_control/datatypes.py
Normal file
107
src/llama_stack/core/access_control/datatypes.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from .conditions import parse_conditions
|
||||
|
||||
|
||||
class Action(StrEnum):
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class Scope(BaseModel):
|
||||
principal: str | None = None
|
||||
actions: Action | list[Action]
|
||||
resource: str | None = None
|
||||
|
||||
|
||||
def _mutually_exclusive(obj, a: str, b: str):
|
||||
if getattr(obj, a) and getattr(obj, b):
|
||||
raise ValueError(f"{a} and {b} are mutually exclusive")
|
||||
|
||||
|
||||
def _require_one_of(obj, a: str, b: str):
|
||||
if not getattr(obj, a) and not getattr(obj, b):
|
||||
raise ValueError(f"on of {a} or {b} is required")
|
||||
|
||||
|
||||
class AccessRule(BaseModel):
|
||||
"""Access rule based loosely on cedar policy language
|
||||
|
||||
A rule defines a list of action either to permit or to forbid. It may specify a
|
||||
principal or a resource that must match for the rule to take effect. The resource
|
||||
to match should be specified in the form of a type qualified identifier, e.g.
|
||||
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
|
||||
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||
requests.
|
||||
|
||||
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||
constraints as to where the rule applies. The constraints supported at present are:
|
||||
|
||||
- 'user with <attr-value> in <attr-name>'
|
||||
- 'user with <attr-value> not in <attr-name>'
|
||||
- 'user is owner'
|
||||
- 'user is not owner'
|
||||
- 'user in owners <attr-name>'
|
||||
- 'user not in owners <attr-name>'
|
||||
|
||||
Rules are tested in order to find a match. If a match is found, the request is
|
||||
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||
request is denied. If no rules are specified, a rule that allows any action as
|
||||
long as the resource attributes match the user attributes is added
|
||||
(i.e. the previous behaviour is the default).
|
||||
|
||||
Some examples in yaml:
|
||||
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [create, read, delete]
|
||||
resource: model::*
|
||||
description: user-1 has full access to all models
|
||||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: user in owner teams
|
||||
description: any user has read access to any resource created by a member of their team
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
resource: vector_store::*
|
||||
unless: user with admin in roles
|
||||
description: only user with admin role can use vector_store resources
|
||||
|
||||
"""
|
||||
|
||||
permit: Scope | None = None
|
||||
forbid: Scope | None = None
|
||||
when: str | list[str] | None = None
|
||||
unless: str | list[str] | None = None
|
||||
description: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_rule_format(self) -> Self:
|
||||
_require_one_of(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "when", "unless")
|
||||
if isinstance(self.when, list):
|
||||
parse_conditions(self.when)
|
||||
elif self.when:
|
||||
parse_conditions([self.when])
|
||||
if isinstance(self.unless, list):
|
||||
parse_conditions(self.unless)
|
||||
elif self.unless:
|
||||
parse_conditions([self.unless])
|
||||
return self
|
||||
164
src/llama_stack/core/build.py
Normal file
164
src/llama_stack/core/build.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import sys
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.datatypes import BuildConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.utils.exec import run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
"aiosqlite",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"uvicorn",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
]
|
||||
|
||||
|
||||
class ApiInput(BaseModel):
|
||||
api: Api
|
||||
provider: str
|
||||
|
||||
|
||||
def get_provider_dependencies(
|
||||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
if isinstance(config, DistributionTemplate):
|
||||
config = config.build_config()
|
||||
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
|
||||
deps = []
|
||||
external_provider_deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
providers_for_api = registry[Api(api_str)]
|
||||
|
||||
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
|
||||
|
||||
for provider in providers:
|
||||
# Providers from BuildConfig and RunConfig are subtly different - not great
|
||||
provider_type = provider if isinstance(provider, str) else provider.provider_type
|
||||
|
||||
if provider_type not in providers_for_api:
|
||||
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
|
||||
|
||||
provider_spec = providers_for_api[provider_type]
|
||||
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
|
||||
# this ensures we install the top level module for our external providers
|
||||
if provider_spec.module:
|
||||
if isinstance(provider_spec.module, str):
|
||||
external_provider_deps.append(provider_spec.module)
|
||||
else:
|
||||
external_provider_deps.extend(provider_spec.module)
|
||||
if hasattr(provider_spec, "pip_packages"):
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
|
||||
raise ValueError("A stack's dependencies cannot have a container image")
|
||||
|
||||
normal_deps = []
|
||||
special_deps = []
|
||||
for package in deps:
|
||||
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
|
||||
special_deps.append(package)
|
||||
else:
|
||||
normal_deps.append(package)
|
||||
|
||||
normal_deps.extend(additional_pip_packages or [])
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(config: BuildConfig):
|
||||
normal_deps, special_deps, _ = get_provider_dependencies(config)
|
||||
|
||||
cprint(
|
||||
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
for special_dep in special_deps:
|
||||
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
|
||||
print()
|
||||
|
||||
|
||||
def build_image(
|
||||
build_config: BuildConfig,
|
||||
image_name: str,
|
||||
distro_or_config: str,
|
||||
run_config: str | None = None,
|
||||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
|
||||
|
||||
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
if build_config.external_apis_dir:
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
"--distro-or-config",
|
||||
distro_or_config,
|
||||
"--image-name",
|
||||
image_name,
|
||||
"--container-base",
|
||||
container_base,
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
# When building from a config file (not a template), include the run config path in the
|
||||
# build arguments
|
||||
if run_config is not None:
|
||||
args.extend(["--run-config", run_config])
|
||||
else:
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# Always pass both arguments, even if empty, to maintain consistent positional arguments
|
||||
if special_deps:
|
||||
args.extend(["--optional-deps", "#".join(special_deps)])
|
||||
if external_provider_deps:
|
||||
args.extend(
|
||||
["--external-provider-deps", "#".join(external_provider_deps)]
|
||||
) # the script will install external provider module, get its deps, and install those too.
|
||||
|
||||
return_code = run_command(args)
|
||||
|
||||
if return_code != 0:
|
||||
log.error(
|
||||
f"Failed to build target {image_name} with return code {return_code}",
|
||||
)
|
||||
|
||||
return return_code
|
||||
205
src/llama_stack/core/client.py
Normal file
205
src/llama_stack/core/client.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
||||
|
||||
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
||||
client_class = create_api_client_class(protocol)
|
||||
impl = client_class(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
def create_api_client_class(protocol) -> type:
|
||||
if protocol in _CLIENT_CLASSES:
|
||||
return _CLIENT_CLASSES[protocol]
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url: str):
|
||||
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.routes = {}
|
||||
|
||||
# Store routes for this protocol
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
|
||||
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
|
||||
|
||||
# TODO: make this more precise, same thing needs to happen in server.py
|
||||
is_streaming = kwargs.get("stream", False)
|
||||
if is_streaming:
|
||||
return self._call_streaming(method_name, *args, **kwargs)
|
||||
else:
|
||||
return await self._call_non_streaming(method_name, *args, **kwargs)
|
||||
|
||||
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
||||
_, sig = self.routes[method_name]
|
||||
|
||||
if sig.return_annotation is None:
|
||||
return_type = None
|
||||
else:
|
||||
return_type = extract_non_async_iterator_type(sig.return_annotation)
|
||||
assert return_type, f"Could not extract return type for {sig.return_annotation}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = self.httpx_request_params(method_name, *args, **kwargs)
|
||||
response = await client.request(**params)
|
||||
response.raise_for_status()
|
||||
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
|
||||
return parse_obj_as(return_type, j)
|
||||
|
||||
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
||||
webmethod, sig = self.routes[method_name]
|
||||
|
||||
return_type = extract_async_iterator_type(sig.return_annotation)
|
||||
assert return_type, f"Could not extract return type for {sig.return_annotation}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = self.httpx_request_params(method_name, *args, **kwargs)
|
||||
async with client.stream(**params) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
data = json.loads(data)
|
||||
if "error" in data:
|
||||
cprint(data, color="red", file=sys.stderr)
|
||||
continue
|
||||
|
||||
yield parse_obj_as(return_type, data)
|
||||
except Exception as e:
|
||||
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
|
||||
cprint(data, color="red", file=sys.stderr)
|
||||
|
||||
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
||||
webmethod, sig = self.routes[method_name]
|
||||
|
||||
parameters = list(sig.parameters.values())[1:] # skip `self`
|
||||
for i, param in enumerate(parameters):
|
||||
if i >= len(args):
|
||||
break
|
||||
kwargs[param.name] = args[i]
|
||||
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
|
||||
if not webmethods:
|
||||
raise RuntimeError(f"Method {method} has no webmethod decorators")
|
||||
|
||||
# Choose the preferred webmethod (non-deprecated if available)
|
||||
preferred_webmethod = None
|
||||
for wm in webmethods:
|
||||
if not getattr(wm, "deprecated", False):
|
||||
preferred_webmethod = wm
|
||||
break
|
||||
|
||||
# If no non-deprecated found, use the first one
|
||||
if preferred_webmethod is None:
|
||||
preferred_webmethod = webmethods[0]
|
||||
|
||||
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, list):
|
||||
return [convert(v) for v in value]
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert(v) for k, v in value.items()}
|
||||
elif isinstance(value, BaseModel):
|
||||
return json.loads(value.model_dump_json())
|
||||
elif isinstance(value, Enum):
|
||||
return value.value
|
||||
else:
|
||||
return value
|
||||
|
||||
params = {}
|
||||
data = {}
|
||||
if webmethod.method == "GET":
|
||||
params.update(kwargs)
|
||||
else:
|
||||
data.update(convert(kwargs))
|
||||
|
||||
ret = dict(
|
||||
method=webmethod.method or "POST",
|
||||
url=url,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if params:
|
||||
ret["params"] = params
|
||||
if data:
|
||||
ret["json"] = data
|
||||
|
||||
return ret
|
||||
|
||||
# Add protocol methods to the wrapper
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
|
||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||
return await self.__acall__(method_name, *args, **kwargs)
|
||||
|
||||
method_impl.__name__ = name
|
||||
method_impl.__qualname__ = f"APIClient.{name}"
|
||||
method_impl.__signature__ = inspect.signature(method)
|
||||
setattr(APIClient, name, method_impl)
|
||||
|
||||
# Name the class after the protocol
|
||||
APIClient.__name__ = f"{protocol.__name__}Client"
|
||||
_CLIENT_CLASSES[protocol] = APIClient
|
||||
return APIClient
|
||||
|
||||
|
||||
# not quite general these methods are
|
||||
def extract_non_async_iterator_type(type_hint):
|
||||
if get_origin(type_hint) is Union:
|
||||
args = get_args(type_hint)
|
||||
for arg in args:
|
||||
if not issubclass(get_origin(arg) or arg, AsyncIterator):
|
||||
return arg
|
||||
return type_hint
|
||||
|
||||
|
||||
def extract_async_iterator_type(type_hint):
|
||||
if get_origin(type_hint) is Union:
|
||||
args = get_args(type_hint)
|
||||
for arg in args:
|
||||
if issubclass(get_origin(arg) or arg, AsyncIterator):
|
||||
inner_args = get_args(arg)
|
||||
return inner_args[0]
|
||||
return None
|
||||
37
src/llama_stack/core/common.sh
Executable file
37
src/llama_stack/core/common.sh
Executable file
|
|
@ -0,0 +1,37 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
cleanup() {
|
||||
# For venv environments, no special cleanup is needed
|
||||
# This function exists to avoid "function not found" errors
|
||||
local env_name="$1"
|
||||
echo "Cleanup called for environment: $env_name"
|
||||
}
|
||||
|
||||
handle_int() {
|
||||
if [ -n "$ENVNAME" ]; then
|
||||
cleanup "$ENVNAME"
|
||||
fi
|
||||
exit 1
|
||||
}
|
||||
|
||||
handle_exit() {
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "\033[1;31mABORTING.\033[0m"
|
||||
if [ -n "$ENVNAME" ]; then
|
||||
cleanup "$ENVNAME"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
|
||||
# check if a command is present
|
||||
is_command_available() {
|
||||
command -v "$1" &>/dev/null
|
||||
}
|
||||
212
src/llama_stack/core/configure.py
Normal file
212
src/llama_stack/core/configure.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
if is_nux:
|
||||
logger.info(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.inspect, Api.providers)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
logger.info(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(configure_single_provider(provider_registry[api], p))
|
||||
logger.info("")
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(p.provider_type for p in plist[i:])
|
||||
logger.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
|
||||
provider_type=provider.provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def upgrade_from_routing_table(
|
||||
config_dict: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
for i, entry in enumerate(entries)
|
||||
]
|
||||
|
||||
providers_by_api = {}
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict) and "provider_type" in provider:
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
]
|
||||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
# Add default storage config if not present
|
||||
if "storage" not in config_dict:
|
||||
config_dict["storage"] = {
|
||||
"backends": {
|
||||
"kv_default": {
|
||||
"type": "kv_sqlite",
|
||||
"db_path": "~/.llama/kvstore.db",
|
||||
},
|
||||
"sql_default": {
|
||||
"type": "sql_sqlite",
|
||||
"db_path": "~/.llama/sql_store.db",
|
||||
},
|
||||
},
|
||||
"stores": {
|
||||
"metadata": {
|
||||
"namespace": "registry",
|
||||
"backend": "kv_default",
|
||||
},
|
||||
"inference": {
|
||||
"table_name": "inference_store",
|
||||
"backend": "sql_default",
|
||||
"max_write_queue_size": 10000,
|
||||
"num_writers": 4,
|
||||
},
|
||||
"conversations": {
|
||||
"table_name": "openai_conversations",
|
||||
"backend": "sql_default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
5
src/llama_stack/core/conversations/__init__.py
Normal file
5
src/llama_stack/core/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
314
src/llama_stack/core/conversations/conversations.py
Normal file
314
src/llama_stack/core/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemInclude,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in conversation service.
|
||||
|
||||
:param run_config: Stack run configuration for resolving persistence
|
||||
:param policy: Access control rules
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
policy: list[AccessRule] = []
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the conversation service implementation."""
|
||||
impl = ConversationServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ConversationServiceImpl(Conversations):
|
||||
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||
|
||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.policy = config.policy
|
||||
|
||||
# Use conversations store reference from run config
|
||||
conversations_ref = config.run_config.storage.stores.conversations
|
||||
if not conversations_ref:
|
||||
raise ValueError("storage.stores.conversations must be configured in run config")
|
||||
|
||||
base_sql_store = sqlstore_impl(conversations_ref)
|
||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the store and create tables."""
|
||||
await self.sql_store.create_table(
|
||||
"openai_conversations",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"items": ColumnType.JSON,
|
||||
"metadata": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"conversation_items",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"conversation_id": ColumnType.STRING,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"item_data": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation."""
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
conversation_id = f"conv_{random_bytes.hex()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
record_data = {
|
||||
"id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"items": [],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="openai_conversations",
|
||||
data=record_data,
|
||||
)
|
||||
|
||||
if items:
|
||||
item_records = []
|
||||
for item in items:
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
item_records.append(item_record)
|
||||
|
||||
await self.sql_store.insert(table="conversation_items", data=item_records)
|
||||
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
created_at=created_at,
|
||||
metadata=metadata,
|
||||
object="conversation",
|
||||
)
|
||||
|
||||
logger.debug(f"Created conversation {conversation_id}")
|
||||
return conversation
|
||||
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID."""
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
return Conversation(
|
||||
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
|
||||
)
|
||||
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID"""
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID."""
|
||||
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
logger.debug(f"Deleted conversation {conversation_id}")
|
||||
return ConversationDeletedResource(id=conversation_id)
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||
"""Validate conversation ID format."""
|
||||
if not conversation_id.startswith("conv_"):
|
||||
raise ValueError(
|
||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
)
|
||||
|
||||
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||
"""Get existing item ID or generate one if missing."""
|
||||
if item.id is None:
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
if item.type == "message":
|
||||
item_id = f"msg_{random_bytes.hex()}"
|
||||
else:
|
||||
item_id = f"item_{random_bytes.hex()}"
|
||||
item_dict["id"] = item_id
|
||||
return item_id
|
||||
return item.id
|
||||
|
||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Validate conversation ID and return the conversation if it exists."""
|
||||
self._validate_conversation_id(conversation_id)
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create (add) items to a conversation."""
|
||||
await self._get_validated_conversation(conversation_id)
|
||||
|
||||
created_items = []
|
||||
base_time = int(time.time())
|
||||
|
||||
for i, item in enumerate(items):
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
# make each timestamp unique to maintain order
|
||||
created_at = base_time + i
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||
try:
|
||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_items",
|
||||
data={"created_at": created_at, "item_data": item_dict},
|
||||
where={"id": item_id},
|
||||
)
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
logger.debug(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||
|
||||
# Convert created items (dicts) to proper ConversationItem types
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=created_items[0]["id"] if created_items else None,
|
||||
last_id=created_items[-1]["id"] if created_items else None,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
# Get item from conversation_items table
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
return adapter.validate_python(record["item_data"])
|
||||
|
||||
async def list_items(
|
||||
self,
|
||||
conversation_id: str,
|
||||
after: str | None = None,
|
||||
include: list[ConversationItemInclude] | None = None,
|
||||
limit: int | None = None,
|
||||
order: Literal["asc", "desc"] | None = None,
|
||||
) -> ConversationItemList:
|
||||
"""List items in the conversation."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
|
||||
# check if conversation exists
|
||||
await self.get_conversation(conversation_id)
|
||||
|
||||
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
|
||||
records = result.data
|
||||
|
||||
if order is not None and order == "asc":
|
||||
records.sort(key=lambda x: x["created_at"])
|
||||
else:
|
||||
records.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
actual_limit = limit or 20
|
||||
|
||||
records = records[:actual_limit]
|
||||
items = [record["item_data"] for record in records]
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
|
||||
|
||||
first_id = response_items[0].id if response_items else None
|
||||
last_id = response_items[-1].id if response_items else None
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
_ = await self._get_validated_conversation(conversation_id)
|
||||
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
await self.sql_store.delete(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
logger.debug(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||
return ConversationItemDeletedResource(id=item_id)
|
||||
629
src/llama_stack/core/datatypes.py
Normal file
629
src/llama_stack/core/datatypes.py
Normal file
|
|
@ -0,0 +1,629 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
KVStoreReference,
|
||||
StorageBackendType,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.log import LoggingConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||
|
||||
|
||||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||
super().__init__(principal=principal, attributes=attributes)
|
||||
|
||||
|
||||
class ResourceWithOwner(Resource):
|
||||
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
owner: User | None = None
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithOwner(Model, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorStoreWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_type: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
container_image: str | None = None
|
||||
routing_table_api: Api
|
||||
module: str
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
container_image: str | None = None
|
||||
|
||||
router_api: Api
|
||||
module: str
|
||||
pip_packages: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any] = {}
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class BuildProvider(BaseModel):
|
||||
provider_type: str
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: str | None = None
|
||||
providers: dict[str, list[BuildProvider]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for telemetry.
|
||||
|
||||
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
|
||||
for env variables to configure the OpenTelemetry SDK.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
|
||||
```
|
||||
"""
|
||||
|
||||
enabled: bool = Field(default=False, description="enable or disable telemetry")
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
|
||||
|
||||
class AuthProviderType(StrEnum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth2 token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||
audience: str = Field(default="llama-stack")
|
||||
verify_tls: bool = Field(default=True)
|
||||
tls_cafile: Path | None = Field(default=None)
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
|
||||
introspection: OAuth2IntrospectionConfig | None = Field(
|
||||
default=None, description="OAuth2 introspection configuration"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class CustomAuthConfig(BaseModel):
|
||||
"""Configuration for custom authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Custom authentication endpoint URL",
|
||||
)
|
||||
|
||||
|
||||
class GitHubTokenAuthConfig(BaseModel):
|
||||
"""Configuration for GitHub token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
|
||||
github_api_base_url: str = Field(
|
||||
default="https://api.github.com",
|
||||
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
|
||||
)
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"login": "roles",
|
||||
"organizations": "teams",
|
||||
},
|
||||
description="Mapping from GitHub user fields to access attributes",
|
||||
)
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
"""Configuration for Kubernetes authentication provider."""
|
||||
|
||||
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||
api_server_url: str = Field(
|
||||
default="https://kubernetes.default.svc",
|
||||
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||
)
|
||||
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
},
|
||||
description="Mapping of Kubernetes user claims to access attributes",
|
||||
)
|
||||
|
||||
@field_validator("api_server_url")
|
||||
@classmethod
|
||||
def validate_api_server_url(cls, v):
|
||||
parsed = urlparse(v)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError(f"api_server_url scheme must be http or https: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("claims_mapping")
|
||||
@classmethod
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
"""Top-level authentication configuration."""
|
||||
|
||||
provider_config: AuthProviderConfig = Field(
|
||||
...,
|
||||
description="Authentication provider configuration",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(
|
||||
default=[],
|
||||
description="Rules for determining access to resources",
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QualifiedModel(BaseModel):
|
||||
"""A qualified model identifier, consisting of a provider ID and a model ID."""
|
||||
|
||||
provider_id: str
|
||||
model_id: str
|
||||
|
||||
|
||||
class VectorStoresConfig(BaseModel):
|
||||
"""Configuration for vector stores in the stack."""
|
||||
|
||||
default_provider_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
|
||||
)
|
||||
default_embedding_model: QualifiedModel | None = Field(
|
||||
default=None,
|
||||
description="Default embedding model configuration for vector stores.",
|
||||
)
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
"""Configuration for default moderations model."""
|
||||
|
||||
default_shield_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
|
||||
)
|
||||
|
||||
|
||||
class QuotaPeriod(StrEnum):
|
||||
DAY = "day"
|
||||
|
||||
|
||||
class QuotaConfig(BaseModel):
|
||||
kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
|
||||
authenticated_max_requests: int = Field(
|
||||
default=1000, description="Max requests for authenticated clients per period"
|
||||
)
|
||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||
|
||||
|
||||
class CORSConfig(BaseModel):
|
||||
allow_origins: list[str] = Field(default_factory=list)
|
||||
allow_origin_regex: str | None = Field(default=None)
|
||||
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||
allow_headers: list[str] = Field(default_factory=list)
|
||||
allow_credentials: bool = Field(default=False)
|
||||
expose_headers: list[str] = Field(default_factory=list)
|
||||
max_age: int = Field(default=600, ge=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_credentials_config(self) -> Self:
|
||||
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||
raise ValueError("Cannot use wildcard origins with credentials enabled")
|
||||
return self
|
||||
|
||||
|
||||
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
|
||||
if cors_config is False or cors_config is None:
|
||||
return None
|
||||
|
||||
if cors_config is True:
|
||||
# dev mode: allow localhost on any port
|
||||
return CORSConfig(
|
||||
allow_origins=[],
|
||||
allow_origin_regex=r"https?://localhost:\d+",
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||
)
|
||||
|
||||
if isinstance(cors_config, CORSConfig):
|
||||
return cors_config
|
||||
|
||||
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||
|
||||
|
||||
class RegisteredResources(BaseModel):
|
||||
"""Registry of resources available in the distribution."""
|
||||
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
|
||||
datasets: list[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
description="Port to listen on",
|
||||
ge=1024,
|
||||
le=65535,
|
||||
)
|
||||
tls_certfile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate file for HTTPS",
|
||||
)
|
||||
tls_keyfile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
tls_cafile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
|
||||
)
|
||||
auth: AuthenticationConfig | None = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
host: str | None = Field(
|
||||
default=None,
|
||||
description="The host the server should listen on",
|
||||
)
|
||||
quota: QuotaConfig | None = Field(
|
||||
default=None,
|
||||
description="Per client quota request configuration",
|
||||
)
|
||||
cors: bool | CORSConfig | None = Field(
|
||||
default=None,
|
||||
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||
"- true: Enable localhost CORS for development\n"
|
||||
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
image_name: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
|
||||
this could be just a hash
|
||||
""",
|
||||
)
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="Reference to the container image if this package refers to a container",
|
||||
)
|
||||
apis: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
providers: dict[str, list[Provider]] = Field(
|
||||
description="""
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
storage: StorageConfig = Field(
|
||||
description="Catalog of named storage backends and references available to the stack",
|
||||
)
|
||||
|
||||
registered_resources: RegisteredResources = Field(
|
||||
default_factory=RegisteredResources,
|
||||
description="Registry of resources available in the distribution",
|
||||
)
|
||||
|
||||
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
)
|
||||
|
||||
external_providers_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
vector_stores: VectorStoresConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for vector stores, including default embedding model",
|
||||
)
|
||||
|
||||
safety: SafetyConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for default moderations model",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
return Path(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_server_stores(self) -> "StackRunConfig":
|
||||
backend_map = self.storage.backends
|
||||
stores = self.storage.stores
|
||||
kv_backends = {
|
||||
name
|
||||
for name, cfg in backend_map.items()
|
||||
if cfg.type
|
||||
in {
|
||||
StorageBackendType.KV_REDIS,
|
||||
StorageBackendType.KV_SQLITE,
|
||||
StorageBackendType.KV_POSTGRES,
|
||||
StorageBackendType.KV_MONGODB,
|
||||
}
|
||||
}
|
||||
sql_backends = {
|
||||
name
|
||||
for name, cfg in backend_map.items()
|
||||
if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
|
||||
}
|
||||
|
||||
def _ensure_backend(reference, expected_set, store_name: str) -> None:
|
||||
if reference is None:
|
||||
return
|
||||
backend_name = reference.backend
|
||||
if backend_name not in backend_map:
|
||||
raise ValueError(
|
||||
f"{store_name} references unknown backend '{backend_name}'. "
|
||||
f"Available backends: {sorted(backend_map)}"
|
||||
)
|
||||
if backend_name not in expected_set:
|
||||
raise ValueError(
|
||||
f"{store_name} references backend '{backend_name}' of type "
|
||||
f"'{backend_map[backend_name].type.value}', but a backend of type "
|
||||
f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
|
||||
)
|
||||
|
||||
_ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
|
||||
_ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
|
||||
_ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
|
||||
_ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
|
||||
_ensure_backend(stores.prompts, kv_backends, "storage.stores.prompts")
|
||||
return self
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
||||
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
|
||||
image_type: str = Field(
|
||||
default="venv",
|
||||
description="Type of package to build (container | venv)",
|
||||
)
|
||||
image_name: str | None = Field(
|
||||
default=None,
|
||||
description="Name of the distribution to build",
|
||||
)
|
||||
external_providers_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
)
|
||||
additional_pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
|
||||
)
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
return Path(v)
|
||||
return v
|
||||
276
src/llama_stack/core/distribution.py
Normal file
276
src/llama_stack/core/distribution.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
return list(Api)
|
||||
|
||||
|
||||
class AutoRoutedApiInfo(BaseModel):
|
||||
routing_table_api: Api
|
||||
router_api: Api
|
||||
|
||||
|
||||
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
||||
return [
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.models,
|
||||
router_api=Api.inference,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.shields,
|
||||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.datasets,
|
||||
router_api=Api.datasetio,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.scoring_functions,
|
||||
router_api=Api.scoring,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.benchmarks,
|
||||
router_api=Api.eval,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.tool_groups,
|
||||
router_api=Api.tool_runtime,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_stores,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def providable_apis() -> list[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]
|
||||
|
||||
|
||||
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
||||
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files or from their provided modules.
|
||||
External providers are loaded from a directory structure like:
|
||||
|
||||
providers.d/
|
||||
remote/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
inline/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
||||
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||
There is special handling for all of the potential cases this method can be called from.
|
||||
|
||||
Args:
|
||||
config: Optional object containing the external providers directory path
|
||||
building: Optional bool delineating whether or not this is being called from a build process
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the external providers directory doesn't exist
|
||||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
# Refresh providable APIs with external APIs if any
|
||||
external_apis = load_external_apis(config)
|
||||
for api, api_spec in external_apis.items():
|
||||
name = api_spec.name.lower()
|
||||
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except (ImportError, AttributeError) as e:
|
||||
# Populate the registry with an empty dict to avoid breaking the provider registry
|
||||
# This assume that the in-tree provider(s) are not available for this API which means
|
||||
# that users will need to use external providers for this API.
|
||||
registry[api] = {}
|
||||
logger.error(
|
||||
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
||||
"Install the API package to load any in-tree providers for this API."
|
||||
)
|
||||
|
||||
# Check if config has external providers
|
||||
if config:
|
||||
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
registry = get_external_providers_from_dir(registry, config)
|
||||
# else lets check for modules in each provider
|
||||
registry = get_external_providers_from_module(
|
||||
registry=registry,
|
||||
config=config,
|
||||
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
|
||||
)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_external_providers_from_dir(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
logger.warning(
|
||||
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
||||
)
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in registry[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
registry[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_external_providers_from_module(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
provider_list = None
|
||||
if isinstance(config, BuildConfig):
|
||||
provider_list = config.distribution_spec.providers.items()
|
||||
else:
|
||||
provider_list = config.providers.items()
|
||||
if provider_list is None:
|
||||
logger.warning("Could not get list of providers from config")
|
||||
return registry
|
||||
for provider_api, providers in provider_list:
|
||||
for provider in providers:
|
||||
if not hasattr(provider, "module") or provider.module is None:
|
||||
continue
|
||||
# get provider using module
|
||||
try:
|
||||
if not building:
|
||||
package_name = provider.module.split("==")[0]
|
||||
module = importlib.import_module(f"{package_name}.provider")
|
||||
# if config class is wrong you will get an error saying module could not be imported
|
||||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
is_external=True,
|
||||
module=provider.module,
|
||||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
) from exc
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
||||
raise e
|
||||
return registry
|
||||
54
src/llama_stack/core/external.py
Normal file
54
src/llama_stack/core/external.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
|
||||
"""Load external API specifications from the configured directory.
|
||||
|
||||
Args:
|
||||
config: StackRunConfig or BuildConfig containing the external APIs directory path
|
||||
|
||||
Returns:
|
||||
A dictionary mapping API names to their specifications
|
||||
"""
|
||||
if not config or not config.external_apis_dir:
|
||||
return {}
|
||||
|
||||
external_apis_dir = config.external_apis_dir.expanduser().resolve()
|
||||
if not external_apis_dir.is_dir():
|
||||
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
|
||||
return {}
|
||||
|
||||
logger.info(f"Loading external APIs from {external_apis_dir}")
|
||||
external_apis: dict[Api, ExternalApiSpec] = {}
|
||||
|
||||
# Look for YAML files in the external APIs directory
|
||||
for yaml_path in external_apis_dir.glob("*.yaml"):
|
||||
try:
|
||||
with open(yaml_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
spec = ExternalApiSpec(**spec_data)
|
||||
api = Api.add(spec.name)
|
||||
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
|
||||
external_apis[api] = spec
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to load external API spec from {yaml_path}")
|
||||
raise
|
||||
|
||||
return external_apis
|
||||
42
src/llama_stack/core/id_generation.py
Normal file
42
src/llama_stack/core/id_generation.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
IdFactory = Callable[[], str]
|
||||
IdOverride = Callable[[str, IdFactory], str]
|
||||
|
||||
_id_override: IdOverride | None = None
|
||||
|
||||
|
||||
def generate_object_id(kind: str, factory: IdFactory) -> str:
|
||||
"""Generate an identifier for the given kind using the provided factory.
|
||||
|
||||
Allows tests to override ID generation deterministically by installing an
|
||||
override callback via :func:`set_id_override`.
|
||||
"""
|
||||
|
||||
override = _id_override
|
||||
if override is not None:
|
||||
return override(kind, factory)
|
||||
return factory()
|
||||
|
||||
|
||||
def set_id_override(override: IdOverride) -> IdOverride | None:
|
||||
"""Install an override used to generate deterministic identifiers."""
|
||||
|
||||
global _id_override
|
||||
|
||||
previous = _id_override
|
||||
_id_override = override
|
||||
return previous
|
||||
|
||||
|
||||
def reset_id_override(previous: IdOverride | None) -> None:
|
||||
"""Restore the previous override returned by :func:`set_id_override`."""
|
||||
|
||||
global _id_override
|
||||
_id_override = previous
|
||||
86
src/llama_stack/core/inspect.py
Normal file
86
src/llama_stack/core/inspect.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from importlib.metadata import version
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inspect import (
|
||||
HealthInfo,
|
||||
Inspect,
|
||||
ListRoutesResponse,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self, config: DistributionInspectConfig, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
if providers: # Only process if there are providers for this API
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
|
||||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status=HealthStatus.OK)
|
||||
|
||||
async def version(self) -> VersionInfo:
|
||||
return VersionInfo(version=version("llama-stack"))
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
540
src/llama_stack/core/library_client.py
Normal file
540
src/llama_stack/core/library_client.py
Normal file
|
|
@ -0,0 +1,540 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import Response as FastAPIResponse
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
APIResponse,
|
||||
AsyncAPIResponse,
|
||||
AsyncLlamaStackClient,
|
||||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
)
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.build import print_pip_install_help
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.core.telemetry import Telemetry
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.log import get_logger, setup_logging
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_pydantic_to_json_value(value: Any) -> Any:
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
elif isinstance(value, list):
|
||||
return [convert_pydantic_to_json_value(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, BaseModel):
|
||||
return json.loads(value.model_dump_json())
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
|
||||
return value
|
||||
|
||||
origin = get_origin(annotation)
|
||||
|
||||
if origin is list:
|
||||
item_type = get_args(annotation)[0]
|
||||
try:
|
||||
return [convert_to_pydantic(item_type, item) for item in value]
|
||||
except Exception:
|
||||
logger.error(f"Error converting list {value} into {item_type}")
|
||||
return value
|
||||
|
||||
elif origin is dict:
|
||||
key_type, val_type = get_args(annotation)
|
||||
try:
|
||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||
except Exception:
|
||||
logger.error(f"Error converting dict {value} into {val_type}")
|
||||
return value
|
||||
|
||||
try:
|
||||
# Handle Pydantic models and discriminated unions
|
||||
return TypeAdapter(annotation).validate_python(value)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
|
||||
# We should get rid of any non-discriminated unions in the API schema.
|
||||
if origin is Union:
|
||||
for union_type in get_args(annotation):
|
||||
try:
|
||||
return convert_to_pydantic(union_type, value)
|
||||
except Exception:
|
||||
continue
|
||||
logger.warning(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
)
|
||||
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
|
||||
|
||||
|
||||
class LibraryClientUploadFile:
|
||||
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
|
||||
|
||||
def __init__(self, filename: str, content: bytes):
|
||||
self.filename = filename
|
||||
self.content = content
|
||||
self.content_type = "application/octet-stream"
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return self.content
|
||||
|
||||
|
||||
class LibraryClientHttpxResponse:
|
||||
"""LibraryClient httpx Response object for FastAPI Response conversion."""
|
||||
|
||||
def __init__(self, response):
|
||||
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
|
||||
self.status_code = response.status_code
|
||||
self.headers = response.headers
|
||||
|
||||
|
||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_distro_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||
)
|
||||
self.provider_data = provider_data
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
Deprecated method for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
loop = self.loop
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
|
||||
def sync_generator():
|
||||
try:
|
||||
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
|
||||
while True:
|
||||
chunk = loop.run_until_complete(async_stream.__anext__())
|
||||
yield chunk
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
return sync_generator()
|
||||
else:
|
||||
try:
|
||||
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
return result
|
||||
|
||||
|
||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_distro_name: str,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
skip_logger_removal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Initialize logging from environment variables first
|
||||
setup_logging()
|
||||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
|
||||
current_sinks = sinks_from_env.strip().lower().split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
if not skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
if config_path_or_distro_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# distribution
|
||||
config = get_stack_run_config_from_distro(config_path_or_distro_name)
|
||||
|
||||
self.config_path_or_distro_name = config_path_or_distro_name
|
||||
self.config = config
|
||||
self.custom_provider_registry = custom_provider_registry
|
||||
self.provider_data = provider_data
|
||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
Initialize the async client.
|
||||
|
||||
Returns:
|
||||
bool: True if initialization was successful
|
||||
"""
|
||||
|
||||
try:
|
||||
self.route_impls = None
|
||||
|
||||
stack = Stack(self.config, self.custom_provider_registry)
|
||||
await stack.initialize()
|
||||
self.impls = stack.impls
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
"Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_distro_name.endswith(".yaml"):
|
||||
providers: dict[str, list[BuildProvider]] = {}
|
||||
for api, run_providers in self.config.providers.items():
|
||||
for provider in run_providers:
|
||||
providers.setdefault(api, []).append(
|
||||
BuildProvider(provider_type=provider.provider_type, module=provider.module)
|
||||
)
|
||||
providers = dict(providers)
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=providers,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
print_pip_install_help(build_config)
|
||||
else:
|
||||
prefix = "!" if in_notebook() else ""
|
||||
cprint(
|
||||
f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
|
||||
"yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
cprint(
|
||||
"Please check your internet connection and try again.",
|
||||
"red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if self.config.telemetry.enabled:
|
||||
setup_logger(Telemetry())
|
||||
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:")
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
return True
|
||||
|
||||
async def request(
|
||||
self,
|
||||
cast_to: Any,
|
||||
options: Any,
|
||||
*,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized. Please call initialize() first.")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = options.headers or {}
|
||||
if self.provider_data:
|
||||
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
|
||||
if all(key not in headers for key in keys):
|
||||
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
|
||||
|
||||
# Use context manager for provider data
|
||||
with request_provider_data_context(headers):
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
|
||||
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
|
||||
"""Handle file uploads from OpenAI client and add them to the request body."""
|
||||
if not (hasattr(options, "files") and options.files):
|
||||
return body, []
|
||||
|
||||
if not isinstance(options.files, list):
|
||||
return body, []
|
||||
|
||||
field_names = []
|
||||
for file_tuple in options.files:
|
||||
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
|
||||
continue
|
||||
|
||||
field_name = file_tuple[0]
|
||||
file_object = file_tuple[1]
|
||||
|
||||
if isinstance(file_object, BytesIO):
|
||||
file_object.seek(0)
|
||||
file_content = file_object.read()
|
||||
filename = getattr(file_object, "name", "uploaded_file")
|
||||
field_names.append(field_name)
|
||||
body[field_name] = LibraryClientUploadFile(filename, file_content)
|
||||
|
||||
return body, field_names
|
||||
|
||||
async def _call_non_streaming(
|
||||
self,
|
||||
*,
|
||||
cast_to: Any,
|
||||
options: Any,
|
||||
):
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
|
||||
if hasattr(options, "extra_json") and options.extra_json:
|
||||
body |= options.extra_json
|
||||
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Handle FastAPI Response objects (e.g., from file content retrieval)
|
||||
if isinstance(result, FastAPIResponse):
|
||||
return LibraryClientHttpxResponse(result)
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
|
||||
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
|
||||
|
||||
status_code = httpx.codes.OK
|
||||
|
||||
if options.method.upper() == "DELETE" and result is None:
|
||||
status_code = httpx.codes.NO_CONTENT
|
||||
|
||||
if status_code == httpx.codes.NO_CONTENT:
|
||||
json_content = ""
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=status_code,
|
||||
content=json_content.encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=convert_pydantic_to_json_value(filtered_body),
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
)
|
||||
return response.parse()
|
||||
|
||||
async def _call_streaming(
|
||||
self,
|
||||
*,
|
||||
cast_to: Any,
|
||||
options: Any,
|
||||
stream_cls: Any,
|
||||
):
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=wrapped_gen,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
),
|
||||
)
|
||||
|
||||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=True,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
body = body or {}
|
||||
exclude_params = exclude_params or set()
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
base_type = get_args(param_type)[0]
|
||||
return {param.name: base_type(**body)}
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
||||
# Check if there's an unwrapped body parameter among multiple parameters
|
||||
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
|
||||
unwrapped_body_param = None
|
||||
for param in params_list:
|
||||
if is_unwrapped_body_param(param.annotation):
|
||||
unwrapped_body_param = param
|
||||
break
|
||||
|
||||
# Convert parameters to Pydantic models where needed
|
||||
converted_body = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name in body:
|
||||
value = body.get(param_name)
|
||||
if param_name in exclude_params:
|
||||
converted_body[param_name] = value
|
||||
else:
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
|
||||
# handle unwrapped body parameter after processing all named parameters
|
||||
if unwrapped_body_param:
|
||||
base_type = get_args(unwrapped_body_param.annotation)[0]
|
||||
# extract only keys not already used by other params
|
||||
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
|
||||
converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)
|
||||
|
||||
return converted_body
|
||||
5
src/llama_stack/core/prompts/__init__.py
Normal file
5
src/llama_stack/core/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
232
src/llama_stack/core/prompts/prompts.py
Normal file
232
src/llama_stack/core/prompts/prompts.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
||||
|
||||
class PromptServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in prompt service.
|
||||
|
||||
:param run_config: Stack run configuration containing distribution info
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the prompt service implementation."""
|
||||
impl = PromptServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class PromptServiceImpl(Prompts):
|
||||
"""Built-in prompt service implementation using KVStore."""
|
||||
|
||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.kvstore: KVStore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# Use prompts store reference from run config
|
||||
prompts_ref = self.config.run_config.storage.stores.prompts
|
||||
if not prompts_ref:
|
||||
raise ValueError("storage.stores.prompts must be configured in run config")
|
||||
self.kvstore = await kvstore_impl(prompts_ref)
|
||||
|
||||
def _get_default_key(self, prompt_id: str) -> str:
|
||||
"""Get the KVStore key that stores the default version number."""
|
||||
return f"prompts:v1:{prompt_id}:default"
|
||||
|
||||
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
||||
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||
if version:
|
||||
return self._get_version_key(prompt_id, str(version))
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
resolved_version = await self.kvstore.get(default_key)
|
||||
if resolved_version is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:default not found")
|
||||
return self._get_version_key(prompt_id, resolved_version)
|
||||
|
||||
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
||||
"""Get the KVStore key for a specific prompt version."""
|
||||
return f"prompts:v1:{prompt_id}:{version}"
|
||||
|
||||
def _get_list_key_prefix(self) -> str:
|
||||
"""Get the key prefix for listing prompts."""
|
||||
return "prompts:v1:"
|
||||
|
||||
def _serialize_prompt(self, prompt: Prompt) -> str:
|
||||
"""Serialize a prompt to JSON string for storage."""
|
||||
return json.dumps(
|
||||
{
|
||||
"prompt_id": prompt.prompt_id,
|
||||
"prompt": prompt.prompt,
|
||||
"version": prompt.version,
|
||||
"variables": prompt.variables or [],
|
||||
"is_default": prompt.is_default,
|
||||
}
|
||||
)
|
||||
|
||||
def _deserialize_prompt(self, data: str) -> Prompt:
|
||||
"""Deserialize a prompt from JSON string."""
|
||||
obj = json.loads(data)
|
||||
return Prompt(
|
||||
prompt_id=obj["prompt_id"],
|
||||
prompt=obj["prompt"],
|
||||
version=obj["version"],
|
||||
variables=obj.get("variables", []),
|
||||
is_default=obj.get("is_default", False),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts (default versions only)."""
|
||||
prefix = self._get_list_key_prefix()
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
prompts = []
|
||||
for key in keys:
|
||||
if key.endswith(":default"):
|
||||
try:
|
||||
default_version = await self.kvstore.get(key)
|
||||
if default_version:
|
||||
prompt_id = key.replace(prefix, "").replace(":default", "")
|
||||
version_key = self._get_version_key(prompt_id, default_version)
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data:
|
||||
prompt = self._deserialize_prompt(data)
|
||||
prompts.append(prompt)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version."""
|
||||
key = await self._get_prompt_key(prompt_id, version)
|
||||
data = await self.kvstore.get(key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||
return self._deserialize_prompt(data)
|
||||
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt."""
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_obj = Prompt(
|
||||
prompt_id=Prompt.generate_prompt_id(),
|
||||
prompt=prompt,
|
||||
version=1,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||
data = self._serialize_prompt(prompt_obj)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||
|
||||
return prompt_obj
|
||||
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version)."""
|
||||
if version < 1:
|
||||
raise ValueError("Version must be >= 1")
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||
|
||||
if version and latest_prompt.version != version:
|
||||
raise ValueError(
|
||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||
)
|
||||
|
||||
current_version = latest_prompt.version if version is None else version
|
||||
new_version = current_version + 1
|
||||
|
||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||
|
||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||
data = self._serialize_prompt(updated_prompt)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
if set_as_default:
|
||||
await self.set_default_version(prompt_id, new_version)
|
||||
|
||||
return updated_prompt
|
||||
|
||||
async def delete_prompt(self, prompt_id: str) -> None:
|
||||
"""Delete a prompt and all its versions."""
|
||||
await self.get_prompt(prompt_id)
|
||||
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
for key in keys:
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt."""
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
default_version = None
|
||||
prompts = []
|
||||
|
||||
for key in keys:
|
||||
data = await self.kvstore.get(key)
|
||||
if key.endswith(":default"):
|
||||
default_version = data
|
||||
else:
|
||||
if data:
|
||||
prompt_obj = self._deserialize_prompt(data)
|
||||
prompts.append(prompt_obj)
|
||||
|
||||
if not prompts:
|
||||
raise ValueError(f"Prompt {prompt_id} not found")
|
||||
|
||||
for prompt in prompts:
|
||||
prompt.is_default = str(prompt.version) == default_version
|
||||
|
||||
prompts.sort(key=lambda x: x.version)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||
version_key = self._get_version_key(prompt_id, str(version))
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, str(version))
|
||||
|
||||
return self._deserialize_prompt(data)
|
||||
137
src/llama_stack/core/providers.py
Normal file
137
src/llama_stack/core/providers.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
for p in providers:
|
||||
# Skip providers that are not enabled
|
||||
if p.provider_id is None:
|
||||
continue
|
||||
ret.append(
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
health=providers_health.get(api, {}).get(
|
||||
p.provider_id,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||
Each API maps to a dictionary of provider IDs to their health responses.
|
||||
"""
|
||||
providers_health: dict[str, dict[str, HealthResponse]] = {}
|
||||
|
||||
# The timeout has to be long enough to allow all the providers to be checked, especially in
|
||||
# the case of the inference router health check since it checks all registered inference
|
||||
# providers.
|
||||
# The timeout must not be equal to the one set by health method for a given implementation,
|
||||
# otherwise we will miss some providers.
|
||||
timeout = 3.0
|
||||
|
||||
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||
# Skip special implementations (inspect/providers) that don't have provider specs
|
||||
if not hasattr(impl, "__provider_spec__"):
|
||||
return None
|
||||
api_name = impl.__provider_spec__.api.name
|
||||
if not hasattr(impl, "health"):
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
return api_name, health
|
||||
except TimeoutError:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
|
||||
)
|
||||
|
||||
# Create tasks for all providers
|
||||
tasks = [check_provider_health(impl) for impl in self.deps.values()]
|
||||
|
||||
# Wait for all health checks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Organize results by API and provider ID
|
||||
for result in results:
|
||||
if result is None: # Skip special implementations
|
||||
continue
|
||||
api_name, health_response = result
|
||||
providers_health[api_name] = health_response
|
||||
|
||||
return providers_health
|
||||
115
src/llama_stack/core/request_headers.py
Normal file
115
src/llama_stack/core/request_headers.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import contextvars
|
||||
import json
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
|
||||
self.provider_data = provider_data or {}
|
||||
if user:
|
||||
self.provider_data["__authenticated_user"] = user
|
||||
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
PROVIDER_DATA_VAR.reset(self.token)
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
def get_request_provider_data(self) -> Any:
|
||||
spec = self.__provider_spec__
|
||||
if not spec:
|
||||
raise ValueError(f"Provider spec not set on {self.__class__}")
|
||||
|
||||
provider_type = spec.provider_type
|
||||
validator_class = spec.provider_data_validator
|
||||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = PROVIDER_DATA_VAR.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
return provider_data
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing provider data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
|
||||
"""Parse provider data from request headers"""
|
||||
keys = [
|
||||
"X-LlamaStack-Provider-Data",
|
||||
"x-llamastack-provider-data",
|
||||
]
|
||||
val = None
|
||||
for key in keys:
|
||||
val = headers.get(key, None)
|
||||
if val:
|
||||
break
|
||||
|
||||
if not val:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
log.error("Provider data not encoded as a JSON object!")
|
||||
return None
|
||||
|
||||
|
||||
def request_provider_data_context(
|
||||
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
|
||||
) -> AbstractContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_authenticated_user() -> User | None:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__authenticated_user")
|
||||
|
||||
|
||||
def user_from_scope(scope: dict) -> User | None:
|
||||
"""Create a User object from ASGI scope data (set by authentication middleware)"""
|
||||
user_attributes = scope.get("user_attributes", {})
|
||||
principal = scope.get("principal", "")
|
||||
|
||||
# auth not enabled
|
||||
if not principal and not user_attributes:
|
||||
return None
|
||||
|
||||
return User(principal=principal, attributes=user_attributes)
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue