mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge branch 'main' into toolcall-arg-recursive-type
This commit is contained in:
commit
113cb4cd65
1665 changed files with 65332 additions and 619376 deletions
|
|
@ -3,8 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.library_client import ( # noqa: F401
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,30 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
from typing import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
|
@ -57,658 +40,12 @@ class ResponseGuardrailSpec(BaseModel):
|
|||
ResponseGuardrail = str | ResponseGuardrailSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
:param content: The content of the attachment.
|
||||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""A document to be used by an agent.
|
||||
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
"""A common step in an agent turn.
|
||||
|
||||
:param turn_id: The ID of the turn.
|
||||
:param step_id: The ID of the step.
|
||||
:param started_at: The time the step started.
|
||||
:param completed_at: The time the step completed.
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class StepType(StrEnum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
:cvar tool_execution: The step is a tool execution step that executes a tool call.
|
||||
:cvar shield_call: The step is a shield call step that checks for safety violations.
|
||||
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
|
||||
"""
|
||||
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
||||
:param model_response: The response from the LLM.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference] = StepType.inference
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
"""A tool execution step in an agent turn.
|
||||
|
||||
:param tool_calls: The tool calls to execute.
|
||||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
"""A shield call step in an agent turn.
|
||||
|
||||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call] = StepType.shield_call
|
||||
violation: SafetyViolation | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
"""A memory retrieval step in an agent turn.
|
||||
|
||||
:param vector_store_ids: The IDs of the vector databases to retrieve context from.
|
||||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||
# TODO: should this be List[str]?
|
||||
vector_store_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System.
|
||||
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param input_messages: List of messages that initiated this turn
|
||||
:param steps: Ordered list of processing steps executed during this turn
|
||||
:param output_message: The model's generated response containing content and metadata
|
||||
:param output_attachments: (Optional) Files or media attached to the agent's response
|
||||
:param started_at: Timestamp when the turn began
|
||||
:param completed_at: (Optional) Timestamp when the turn finished, if completed
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System.
|
||||
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param session_name: Human-readable name for the session
|
||||
:param turns: List of all turns that have occurred in this session
|
||||
:param started_at: Timestamp when the session was created
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: list[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = str | AgentToolGroupWithArgs
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
||||
max_infer_iters: int | None = 10
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_config:
|
||||
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
|
||||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
else:
|
||||
params = {}
|
||||
if self.tool_choice:
|
||||
params["tool_choice"] = self.tool_choice
|
||||
if self.tool_prompt_format:
|
||||
params["tool_prompt_format"] = self.tool_prompt_format
|
||||
self.tool_config = ToolConfig(**params)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
"""Configuration for an agent.
|
||||
|
||||
:param model: The model identifier to use for the agent
|
||||
:param instructions: The system instructions for the agent
|
||||
:param name: Optional name for the agent, used in telemetry and identification
|
||||
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
|
||||
:param response_format: Optional response format configuration
|
||||
"""
|
||||
|
||||
model: str
|
||||
instructions: str
|
||||
name: str | None = None
|
||||
enable_session_persistence: bool | None = False
|
||||
response_format: ResponseFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
"""An agent instance with configuration and metadata.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param agent_config: Configuration settings for the agent
|
||||
:param created_at: Timestamp when the agent was created
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(StrEnum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
turn_awaiting_input = "turn_awaiting_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
"""Payload for step start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param metadata: (Optional) Additional metadata for the step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
"""Payload for step completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param step_details: Complete details of the executed step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
"""Payload for step progress events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param delta: Incremental content changes during step execution
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
delta: ContentDelta
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
"""Payload for turn start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
"""Payload for turn completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Complete turn data including all steps and results
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
"""Payload for turn awaiting input events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Turn data when waiting for external tool responses
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
|
||||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
AgentTurnResponseStepStartPayload
|
||||
| AgentTurnResponseStepProgressPayload
|
||||
| AgentTurnResponseStepCompletePayload
|
||||
| AgentTurnResponseTurnStartPayload
|
||||
| AgentTurnResponseTurnCompletePayload
|
||||
| AgentTurnResponseTurnAwaitingInputPayload,
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""An event in an agent turn response stream.
|
||||
|
||||
:param payload: Event-specific payload containing event data
|
||||
"""
|
||||
|
||||
payload: AgentTurnResponseEventPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent.
|
||||
|
||||
:param agent_id: Unique identifier for the created agent
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentSessionCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent session.
|
||||
|
||||
:param session_id: Unique identifier for the created session
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
"""Request to create a new turn for an agent.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param messages: List of messages to start the turn with
|
||||
:param documents: (Optional) List of documents to provide to the agent
|
||||
:param toolgroups: (Optional) List of tool groups to make available for this turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param tool_config: (Optional) Tool configuration to override agent defaults
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: list[UserMessage | ToolResponseMessage]
|
||||
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
|
||||
stream: bool | None = False
|
||||
tool_config: ToolConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
"""Request to resume an agent turn with tool responses.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param tool_responses: List of tool responses to submit to continue the turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: list[ToolResponse]
|
||||
stream: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
"""Streamed agent turn completion response.
|
||||
|
||||
:param event: Individual event in the agent turn response stream
|
||||
"""
|
||||
|
||||
event: AgentTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentStepResponse(BaseModel):
|
||||
"""Response containing details of a specific agent step.
|
||||
|
||||
:param step: The complete step data and execution details
|
||||
"""
|
||||
|
||||
step: Step
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents
|
||||
|
||||
APIs for creating and interacting with agentic systems."""
|
||||
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
"""Create an agent with the given configuration.
|
||||
|
||||
:param agent_config: The configuration for the agent.
|
||||
:returns: An AgentCreateResponse with the agent ID.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
:param session_id: The ID of the session to create the turn for.
|
||||
:param messages: List of messages to start the turn with.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param documents: (Optional) List of documents to create the turn with.
|
||||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
||||
:param agent_id: The ID of the agent to resume.
|
||||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn:
|
||||
"""Retrieve an agent turn by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the turn for.
|
||||
:param session_id: The ID of the session to get the turn for.
|
||||
:param turn_id: The ID of the turn to get.
|
||||
:returns: A Turn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_step(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
step_id: str,
|
||||
) -> AgentStepResponse:
|
||||
"""Retrieve an agent step by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the step for.
|
||||
:param session_id: The ID of the session to get the step for.
|
||||
:param turn_id: The ID of the turn to get the step for.
|
||||
:param step_id: The ID of the step to get.
|
||||
:returns: An AgentStepResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
"""Create a new session for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the session for.
|
||||
:param session_name: The name of the session to create.
|
||||
:returns: An AgentSessionCreateResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
:returns: A Session.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID and its associated turns.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent by its ID and its associated sessions and turns.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of agents to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of sessions to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
# We situate the OpenAI Responses API in the Agents API just like we did things
|
||||
# for Inference. The Responses API, in its intent, serves the same purpose as
|
||||
# the Agents API above -- it is essentially a lightweight "agentic loop" with
|
||||
|
|
@ -750,6 +87,7 @@ class Agents(Protocol):
|
|||
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
max_tool_calls: int | None = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a model response.
|
||||
|
||||
|
|
@ -760,6 +98,7 @@ class Agents(Protocol):
|
|||
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -403,7 +403,7 @@ class OpenAIResponseText(BaseModel):
|
|||
|
||||
|
||||
# Must match type Literals of OpenAIResponseInputToolWebSearch below
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11", "web_search_2025_08_26"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -415,9 +415,12 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
|
|||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
)
|
||||
type: (
|
||||
Literal["web_search"]
|
||||
| Literal["web_search_preview"]
|
||||
| Literal["web_search_preview_2025_03_11"]
|
||||
| Literal["web_search_2025_08_26"]
|
||||
) = "web_search"
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
|
@ -591,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param usage: (Optional) Token usage information for the response
|
||||
:param instructions: (Optional) System message inserted into the model's context
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
|
@ -612,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
truncation: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
instructions: str | None = None
|
||||
max_tool_calls: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
|
||||
async def register_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -95,7 +95,7 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark.
|
||||
|
||||
|
|
|
|||
|
|
@ -56,14 +56,6 @@ class ToolGroupNotFoundError(ResourceNotFoundError):
|
|||
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced session or access is denied"""
|
||||
|
||||
def __init__(self, session_name: str) -> None:
|
||||
message = f"Session '{session_name}' not found or access denied."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelTypeError(TypeError):
|
||||
"""raised when a model is present but not the correct type"""
|
||||
|
||||
|
|
|
|||
|
|
@ -34,3 +34,44 @@ class PaginatedResponse(BaseModel):
|
|||
data: list[dict[str, Any]]
|
||||
has_more: bool
|
||||
url: str | None = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be included with the response
|
||||
# To do this, we will need to augment all response types with a metrics field.
|
||||
# We have hit a blocker from stainless SDK that prevents us from doing this.
|
||||
# The blocker is that if we were to augment the response types that have a data field
|
||||
# in them like so
|
||||
# class ListModelsResponse(BaseModel):
|
||||
# metrics: Optional[List[MetricEvent]] = None
|
||||
# data: List[Models]
|
||||
# ...
|
||||
# The client SDK will need to access the data by using a .data field, which is not
|
||||
# ergonomic. Stainless SDK does support unwrapping the response type, but it
|
||||
# requires that the response type to only have a single field.
|
||||
|
||||
# We will need a way in the client SDK to signal that the metrics are needed
|
||||
# and if they are needed, the client SDK has to return the full response type
|
||||
# without unwrapping it.
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
"""A metric value included in API responses.
|
||||
:param metric: The name of the metric
|
||||
:param value: The numeric value of the metric
|
||||
:param unit: (Optional) The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
metric: str
|
||||
value: int | float
|
||||
unit: str | None = None
|
||||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
"""Mixin class for API responses that can include metrics.
|
||||
:param metrics: (Optional) List of metrics associated with the API response
|
||||
"""
|
||||
|
||||
metrics: list[MetricInResponse] | None = None
|
||||
|
|
|
|||
22
src/llama_stack/apis/common/tracing.py
Normal file
22
src/llama_stack/apis/common/tracing.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def telemetry_traceable(cls):
|
||||
"""
|
||||
Mark a protocol for automatic tracing when telemetry is enabled.
|
||||
|
||||
This is a metadata-only decorator with no dependencies on core.
|
||||
Actual tracing is applied by core routers at runtime if telemetry is enabled.
|
||||
|
||||
Usage:
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class MyProtocol(Protocol):
|
||||
...
|
||||
"""
|
||||
cls.__marked_for_tracing__ = True
|
||||
return cls
|
||||
|
|
@ -103,17 +103,6 @@ class CompletionInputType(BaseModel):
|
|||
type: Literal["completion_input"] = "completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnInputType(BaseModel):
|
||||
"""Parameter type for agent turn input.
|
||||
|
||||
:param type: Discriminator type. Always "agent_turn_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages (may also include attachments?)
|
||||
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogType(BaseModel):
|
||||
"""Parameter type for dialog data with semantic output labels.
|
||||
|
|
@ -135,8 +124,7 @@ ParamType = Annotated[
|
|||
| JsonType
|
||||
| UnionType
|
||||
| ChatCompletionInputType
|
||||
| CompletionInputType
|
||||
| AgentTurnInputType,
|
||||
| CompletionInputType,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
|
|
|||
|
|
@ -6,26 +6,22 @@
|
|||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationCreateRequest,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
ConversationUpdateRequest,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"ConversationCreateRequest",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItem",
|
||||
"ConversationItemCreateRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"ConversationItemList",
|
||||
"Conversations",
|
||||
"ConversationUpdateRequest",
|
||||
"Metadata",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
Metadata = dict[str, str]
|
||||
|
|
@ -102,32 +102,6 @@ register_schema(ConversationItem, name="ConversationItem")
|
|||
# ]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationCreateRequest(BaseModel):
|
||||
"""Request body for creating a conversation."""
|
||||
|
||||
items: list[ConversationItem] | None = Field(
|
||||
default=[],
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default={},
|
||||
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
|
||||
max_length=16,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationUpdateRequest(BaseModel):
|
||||
"""Request body for updating a conversation."""
|
||||
|
||||
metadata: Metadata = Field(
|
||||
...,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation."""
|
||||
|
|
@ -183,7 +157,7 @@ class ConversationItemDeletedResource(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Conversations(Protocol):
|
||||
"""Conversations
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
|
|||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
|
|
@ -235,7 +235,7 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -4,17 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -32,19 +31,7 @@ class ModelCandidate(BaseModel):
|
|||
system_message: SystemMessage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
EvalCandidate = ModelCandidate
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
)
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
self,
|
||||
content: str = "",
|
||||
end: str = "\n",
|
||||
color="white",
|
||||
):
|
||||
self.content = content
|
||||
self.color = color
|
||||
self.end = "\n" if end is None else end
|
||||
|
||||
def print(self, flush=True):
|
||||
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
|
||||
|
||||
|
||||
class EventLogger:
|
||||
async def log(self, event_generator):
|
||||
async for chunk in event_generator:
|
||||
if isinstance(chunk, ChatCompletionResponseStreamChunk):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.progress:
|
||||
yield LogEvent(event.delta, color="yellow", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
yield LogEvent("")
|
||||
else:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
yield LogEvent(chunk.completion_message.content, color="yellow")
|
||||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
|
|
@ -15,29 +15,18 @@ from typing import (
|
|||
)
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.responses import (
|
||||
Order,
|
||||
)
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
|
|
@ -202,58 +191,6 @@ class ToolResponseMessage(BaseModel):
|
|||
content: InterleavedContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionMessage(BaseModel):
|
||||
"""A message containing the model's (assistant) response in a chat conversation.
|
||||
|
||||
:param role: Must be "assistant" to identify this as the model's response
|
||||
:param content: The content of the model's response
|
||||
:param stop_reason: Reason why the model stopped generating. Options are:
|
||||
- `StopReason.end_of_turn`: The model finished generating the entire response.
|
||||
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
|
||||
- `StopReason.out_of_tokens`: The model ran out of token budget.
|
||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response from a tool invocation.
|
||||
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was invoked
|
||||
:param content: The response content from the tool
|
||||
:param metadata: (Optional) Additional metadata about the tool response
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
content: InterleavedContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class ToolChoice(Enum):
|
||||
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
|
||||
|
||||
|
|
@ -290,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
|
|||
progress = "progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseEvent(BaseModel):
|
||||
"""An event during chat completion generation.
|
||||
|
||||
:param event_type: Type of the event
|
||||
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
:param stop_reason: Optional reason why generation stopped, if complete
|
||||
"""
|
||||
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: ContentDelta
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
|
||||
class ResponseFormatType(StrEnum):
|
||||
"""Types of formats for structured (guided) decoding.
|
||||
|
||||
|
|
@ -358,34 +279,6 @@ class CompletionRequest(BaseModel):
|
|||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
:param stop_reason: Reason why generation stopped
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
:param stop_reason: Optional reason why generation stopped, if complete
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
|
||||
delta: str
|
||||
stop_reason: StopReason | None = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
class SystemMessageBehavior(Enum):
|
||||
"""Config for how to override the default system prompt.
|
||||
|
||||
|
|
@ -399,70 +292,6 @@ class SystemMessageBehavior(Enum):
|
|||
replace = "replace"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration for tool use.
|
||||
|
||||
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
|
||||
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
|
||||
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
|
||||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||
"""
|
||||
|
||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.tool_choice, str):
|
||||
try:
|
||||
self.tool_choice = ToolChoice[self.tool_choice]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
@json_schema_type
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[Message]
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
|
||||
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
||||
|
||||
response_format: ResponseFormat | None = None
|
||||
stream: bool | None = False
|
||||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed chat completion response.
|
||||
|
||||
:param event: The event containing the new content
|
||||
"""
|
||||
|
||||
event: ChatCompletionResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponse(MetricResponseMixin):
|
||||
"""Response from a chat completion request.
|
||||
|
||||
:param completion_message: The complete response message
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
|
||||
completion_message: CompletionMessage
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
"""Response containing generated embeddings.
|
||||
|
|
@ -1160,7 +989,7 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class InferenceProvider(Protocol):
|
||||
"""
|
||||
This protocol defines the interface that should be implemented by all inference providers.
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class Inspect(Protocol):
|
|||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
|
||||
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Models(Protocol):
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
|
@ -136,7 +136,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -158,7 +158,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ class ListPromptsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Prompts(Protocol):
|
||||
"""Prompts
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class ShieldStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
|
|
@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
"""Unregister a scoring function.
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ class ListShieldsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
|
|
@ -67,7 +67,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
@ -85,7 +85,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,18 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RRFRanker(BaseModel):
|
||||
"""
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
|
|
@ -30,7 +25,6 @@ class RRFRanker(BaseModel):
|
|||
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WeightedRanker(BaseModel):
|
||||
"""
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
|
|
@ -55,10 +49,8 @@ Ranker = Annotated[
|
|||
RRFRanker | WeightedRanker,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Ranker, name="Ranker")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
|
|
@ -75,7 +67,6 @@ class RAGDocument(BaseModel):
|
|||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
"""Result of a RAG query containing retrieved content and metadata.
|
||||
|
||||
|
|
@ -87,7 +78,6 @@ class RAGQueryResult(BaseModel):
|
|||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryGenerator(Enum):
|
||||
"""Types of query generators for RAG systems.
|
||||
|
||||
|
|
@ -101,7 +91,6 @@ class RAGQueryGenerator(Enum):
|
|||
custom = "custom"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGSearchMode(StrEnum):
|
||||
"""
|
||||
Search modes for RAG query retrieval:
|
||||
|
|
@ -115,7 +104,6 @@ class RAGSearchMode(StrEnum):
|
|||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the default RAG query generator.
|
||||
|
||||
|
|
@ -127,7 +115,6 @@ class DefaultRAGQueryGeneratorConfig(BaseModel):
|
|||
separator: str = " "
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||
"""Configuration for the LLM-based RAG query generator.
|
||||
|
||||
|
|
@ -145,10 +132,8 @@ RAGQueryGeneratorConfig = Annotated[
|
|||
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the RAG query generation.
|
||||
|
|
@ -181,38 +166,3 @@ class RAGQueryConfig(BaseModel):
|
|||
if len(v) == 0:
|
||||
raise ValueError("chunk_template must not be empty")
|
||||
return v
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_store_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system.
|
||||
|
||||
:param documents: List of documents to index in the RAG system
|
||||
:param vector_store_id: ID of the vector database to store the document embeddings
|
||||
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_store_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
"""Query the RAG system for context; typically invoked by the agent.
|
||||
|
||||
:param content: The query content to search for in the indexed documents
|
||||
:param vector_store_ids: List of vector database IDs to search within
|
||||
:param query_config: (Optional) Configuration parameters for the query operation
|
||||
:returns: RAGQueryResult containing the retrieved content and metadata
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -11,13 +11,11 @@ from pydantic import BaseModel
|
|||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
|
|
@ -109,9 +107,9 @@ class ListToolDefsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def register_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -169,7 +167,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -191,12 +189,10 @@ class SpecialToolGroup(Enum):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_runtime_tools(
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Body
|
||||
from fastapi import Body, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.strong_typing.schema import register_schema
|
||||
|
||||
|
|
@ -224,10 +224,16 @@ class VectorStoreContent(BaseModel):
|
|||
|
||||
:param type: Content type, currently only "text" is supported
|
||||
:param text: The actual text content
|
||||
:param embedding: Optional embedding vector for this content chunk
|
||||
:param chunk_metadata: Optional chunk metadata
|
||||
:param metadata: Optional user-defined metadata
|
||||
"""
|
||||
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
embedding: list[float] | None = None
|
||||
chunk_metadata: ChunkMetadata | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -260,7 +266,7 @@ class VectorStoreSearchResponsePage(BaseModel):
|
|||
"""
|
||||
|
||||
object: str = "vector_store.search_results.page"
|
||||
search_query: str
|
||||
search_query: list[str]
|
||||
data: list[VectorStoreSearchResponse]
|
||||
has_more: bool = False
|
||||
next_page: str | None = None
|
||||
|
|
@ -280,6 +286,22 @@ class VectorStoreDeleteResponse(BaseModel):
|
|||
deleted: bool = True
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileContentResponse(BaseModel):
|
||||
"""Represents the parsed content of a vector store file.
|
||||
|
||||
:param object: The object type, which is always `vector_store.file_content.page`
|
||||
:param data: Parsed content of the file
|
||||
:param has_more: Indicates if there are more content pages to fetch
|
||||
:param next_page: The token for the next page, if any
|
||||
"""
|
||||
|
||||
object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
|
||||
data: list[VectorStoreContent]
|
||||
has_more: bool = False
|
||||
next_page: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreChunkingStrategyAuto(BaseModel):
|
||||
"""Automatic chunking strategy for vector store files.
|
||||
|
|
@ -395,22 +417,6 @@ class VectorStoreListFilesResponse(BaseModel):
|
|||
has_more: bool = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileContentsResponse(BaseModel):
|
||||
"""Response from retrieving the contents of a vector store file.
|
||||
|
||||
:param file_id: Unique identifier for the file
|
||||
:param filename: Name of the file
|
||||
:param attributes: Key-value attributes associated with the file
|
||||
:param content: List of content items from the file
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
attributes: dict[str, Any]
|
||||
content: list[VectorStoreContent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorStoreFileDeleteResponse(BaseModel):
|
||||
"""Response from deleting a vector store file.
|
||||
|
|
@ -478,7 +484,7 @@ class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
|
|||
name: str | None = None
|
||||
file_ids: list[str] | None = None
|
||||
expires_after: dict[str, Any] | None = None
|
||||
chunking_strategy: dict[str, Any] | None = None
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
|
@ -502,7 +508,7 @@ class VectorStoreTable(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
@telemetry_traceable
|
||||
class VectorIO(Protocol):
|
||||
vector_store_table: VectorStoreTable | None = None
|
||||
|
||||
|
|
@ -732,12 +738,16 @@ class VectorIO(Protocol):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
include_embeddings: Annotated[bool | None, Query(default=False)] = False,
|
||||
include_metadata: Annotated[bool | None, Query(default=False)] = False,
|
||||
) -> VectorStoreFileContentResponse:
|
||||
"""Retrieves the contents of a vector store file.
|
||||
|
||||
:param vector_store_id: The ID of the vector store containing the file to retrieve.
|
||||
:param file_id: The ID of the file to retrieve.
|
||||
:returns: A list of InterleavedContent representing the file contents.
|
||||
:param include_embeddings: Whether to include embedding vectors in the response.
|
||||
:param include_metadata: Whether to include chunk metadata in the response.
|
||||
:returns: File contents, optionally with embeddings and metadata based on query parameters.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ class StackListDeps(Subcommand):
|
|||
def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
# can be fast to load and reduces dependencies
|
||||
if not args.config and not args.providers:
|
||||
self.parser.print_help()
|
||||
self.parser.exit()
|
||||
|
||||
from ._list_deps import run_stack_list_deps_command
|
||||
|
||||
return run_stack_list_deps_command(args)
|
||||
|
|
|
|||
|
|
@ -9,48 +9,69 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
|
||||
class StackListBuilds(Subcommand):
|
||||
"""List built stacks in .llama/distributions directory"""
|
||||
"""List available distributions (both built-in and custom)"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list",
|
||||
prog="llama stack list",
|
||||
description="list the build stacks",
|
||||
description="list available distributions",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._list_stack_command)
|
||||
|
||||
def _get_distribution_dirs(self) -> dict[str, Path]:
|
||||
"""Return a dictionary of distribution names and their paths"""
|
||||
distributions = {}
|
||||
dist_dir = Path.home() / ".llama" / "distributions"
|
||||
def _get_distribution_dirs(self) -> dict[str, tuple[Path, str]]:
|
||||
"""Return a dictionary of distribution names and their paths with source type
|
||||
|
||||
Returns:
|
||||
dict mapping distro name to (path, source_type) where source_type is 'built-in' or 'custom'
|
||||
"""
|
||||
distributions = {}
|
||||
|
||||
# Get built-in distributions from source code
|
||||
distro_dir = Path(__file__).parent.parent.parent / "distributions"
|
||||
if distro_dir.exists():
|
||||
for stack_dir in distro_dir.iterdir():
|
||||
if stack_dir.is_dir() and not stack_dir.name.startswith(".") and not stack_dir.name.startswith("__"):
|
||||
distributions[stack_dir.name] = (stack_dir, "built-in")
|
||||
|
||||
# Get custom/run distributions from ~/.llama/distributions
|
||||
# These override built-in ones if they have the same name
|
||||
if DISTRIBS_BASE_DIR.exists():
|
||||
for stack_dir in DISTRIBS_BASE_DIR.iterdir():
|
||||
if stack_dir.is_dir() and not stack_dir.name.startswith("."):
|
||||
# Clean up the name (remove llamastack- prefix if present)
|
||||
name = stack_dir.name.replace("llamastack-", "")
|
||||
distributions[name] = (stack_dir, "custom")
|
||||
|
||||
if dist_dir.exists():
|
||||
for stack_dir in dist_dir.iterdir():
|
||||
if stack_dir.is_dir():
|
||||
distributions[stack_dir.name] = stack_dir
|
||||
return distributions
|
||||
|
||||
def _list_stack_command(self, args: argparse.Namespace) -> None:
|
||||
distributions = self._get_distribution_dirs()
|
||||
|
||||
if not distributions:
|
||||
print("No stacks found in ~/.llama/distributions")
|
||||
print("No distributions found")
|
||||
return
|
||||
|
||||
headers = ["Stack Name", "Path"]
|
||||
headers.extend(["Build Config", "Run Config"])
|
||||
headers = ["Stack Name", "Source", "Path", "Build Config", "Run Config"]
|
||||
rows = []
|
||||
for name, path in distributions.items():
|
||||
row = [name, str(path)]
|
||||
for name, (path, source_type) in sorted(distributions.items()):
|
||||
row = [name, source_type, str(path)]
|
||||
# Check for build and run config files
|
||||
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
|
||||
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
|
||||
# For built-in distributions, configs are named build.yaml and run.yaml
|
||||
# For custom distributions, configs are named {name}-build.yaml and {name}-run.yaml
|
||||
if source_type == "built-in":
|
||||
build_config = "Yes" if (path / "build.yaml").exists() else "No"
|
||||
run_config = "Yes" if (path / "run.yaml").exists() else "No"
|
||||
else:
|
||||
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
|
||||
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
|
||||
row.extend([build_config, run_config])
|
||||
rows.append(row)
|
||||
print_table(rows, headers, separate_rows=True)
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ class StackRun(Subcommand):
|
|||
)
|
||||
return
|
||||
|
||||
ui_dir = REPO_ROOT / "llama_stack" / "ui"
|
||||
ui_dir = REPO_ROOT / "llama_stack_ui"
|
||||
logs_dir = Path("~/.llama/ui/logs").expanduser()
|
||||
try:
|
||||
# Create logs directory if it doesn't exist
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import sys
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -12,9 +11,6 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.core.datatypes import BuildConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.utils.exec import run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig):
|
|||
for special_dep in special_deps:
|
||||
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
|
||||
print()
|
||||
|
||||
|
||||
def build_image(
|
||||
build_config: BuildConfig,
|
||||
image_name: str,
|
||||
distro_or_config: str,
|
||||
run_config: str | None = None,
|
||||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
|
||||
|
||||
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
if build_config.external_apis_dir:
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
"--distro-or-config",
|
||||
distro_or_config,
|
||||
"--image-name",
|
||||
image_name,
|
||||
"--container-base",
|
||||
container_base,
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
# When building from a config file (not a template), include the run config path in the
|
||||
# build arguments
|
||||
if run_config is not None:
|
||||
args.extend(["--run-config", run_config])
|
||||
else:
|
||||
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# Always pass both arguments, even if empty, to maintain consistent positional arguments
|
||||
if special_deps:
|
||||
args.extend(["--optional-deps", "#".join(special_deps)])
|
||||
if external_provider_deps:
|
||||
args.extend(
|
||||
["--external-provider-deps", "#".join(external_provider_deps)]
|
||||
) # the script will install external provider module, get its deps, and install those too.
|
||||
|
||||
return_code = run_command(args)
|
||||
|
||||
if return_code != 0:
|
||||
log.error(
|
||||
f"Failed to build target {image_name} with return code {return_code}",
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
|
|
|||
|
|
@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
|
|||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||
try:
|
||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_items",
|
||||
data={"created_at": created_at, "item_data": item_dict},
|
||||
where={"id": item_id},
|
||||
)
|
||||
await self.sql_store.upsert(
|
||||
table="conversation_items",
|
||||
data=item_record,
|
||||
conflict_columns=["id"],
|
||||
)
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
|
|||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
|
@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
|
|||
# Helper function to determine if a route should be included based on api_filter
|
||||
def should_include_route(webmethod) -> bool:
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated v1 APIs
|
||||
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
|
||||
# Default: only non-deprecated APIs
|
||||
return not webmethod.deprecated
|
||||
elif api_filter == "deprecated":
|
||||
# Special filter: show deprecated routes regardless of their actual level
|
||||
return bool(webmethod.deprecated)
|
||||
|
|
|
|||
|
|
@ -18,14 +18,21 @@ from typing import Any, TypeVar, Union, get_args, get_origin
|
|||
import httpx
|
||||
import yaml
|
||||
from fastapi import Response as FastAPIResponse
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
APIResponse,
|
||||
AsyncAPIResponse,
|
||||
AsyncLlamaStackClient,
|
||||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
)
|
||||
|
||||
try:
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
APIResponse,
|
||||
AsyncAPIResponse,
|
||||
AsyncLlamaStackClient,
|
||||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"llama-stack-client is not installed. Please install it with `uv pip install llama-stack[client]`."
|
||||
) from e
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
|
@ -382,6 +389,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
# Pass through params that aren't already handled as path params
|
||||
if options.params:
|
||||
extra_query_params = {k: v for k, v in options.params.items() if k not in path_params}
|
||||
if extra_query_params:
|
||||
body["extra_query"] = extra_query_params
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
|
|
|||
|
|
@ -397,6 +397,18 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
# Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker
|
||||
if run_config.telemetry.enabled:
|
||||
traced_classes = [
|
||||
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
||||
]
|
||||
|
||||
if traced_classes:
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
for cls in traced_classes:
|
||||
trace_protocol(cls)
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ async def get_routing_table_impl(
|
|||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
@ -92,5 +93,6 @@ async def get_auto_router_impl(
|
|||
api_to_dep_impl["safety_config"] = run_config.safety
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ class InferenceRouter(Inference):
|
|||
|
||||
response = await provider.openai_completion(params)
|
||||
response.model = request_model_id
|
||||
if self.telemetry_enabled:
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
|
|
@ -253,7 +253,7 @@ class InferenceRouter(Inference):
|
|||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry_enabled:
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
|
@ -52,7 +52,7 @@ class SafetyRouter(Safety):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
|
|
|
|||
|
|
@ -8,14 +8,9 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -26,36 +21,6 @@ logger = get_logger(name=__name__, category="core::routers")
|
|||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
class RagToolImpl(RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_store_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_store_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_store_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
|
|
@ -63,11 +28,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
self.rag_tool = self.RagToolImpl(routing_table)
|
||||
for method in ("query", "insert"):
|
||||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -20,9 +20,11 @@ from llama_stack.apis.vector_io import (
|
|||
SearchRankingOptions,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreChunkingStrategyStaticConfig,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileBatchObject,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileContentResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFilesListInBatchResponse,
|
||||
|
|
@ -167,6 +169,13 @@ class VectorIORouter(VectorIO):
|
|||
if embedding_dimension is not None:
|
||||
params.model_extra["embedding_dimension"] = embedding_dimension
|
||||
|
||||
# Set chunking strategy explicitly if not provided
|
||||
if params.chunking_strategy is None or params.chunking_strategy.type == "auto":
|
||||
# actualize the chunking strategy to static
|
||||
params.chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig()
|
||||
)
|
||||
|
||||
return await provider.openai_create_vector_store(params)
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
|
|
@ -238,6 +247,13 @@ class VectorIORouter(VectorIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
|
||||
# Check if provider_id is being changed (not supported)
|
||||
if metadata and "provider_id" in metadata:
|
||||
current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id)
|
||||
if current_store and current_store.provider_id != metadata["provider_id"]:
|
||||
raise ValueError("provider_id cannot be changed after vector store creation")
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -283,6 +299,8 @@ class VectorIORouter(VectorIO):
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
if chunking_strategy is None or chunking_strategy.type == "auto":
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig())
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -327,12 +345,19 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
include_embeddings: bool | None = False,
|
||||
include_metadata: bool | None = False,
|
||||
) -> VectorStoreFileContentResponse:
|
||||
logger.debug(
|
||||
f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, "
|
||||
f"include_embeddings={include_embeddings}, include_metadata={include_metadata}"
|
||||
)
|
||||
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
include_embeddings=include_embeddings,
|
||||
include_metadata=include_metadata,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
|
|||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileContentResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
|
|
@ -195,12 +195,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
include_embeddings: bool | None = False,
|
||||
include_metadata: bool | None = False,
|
||||
) -> VectorStoreFileContentResponse:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
include_embeddings=include_embeddings,
|
||||
include_metadata=include_metadata,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from aiohttp import hdrs
|
|||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
|
|
@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
|
|||
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
return {
|
||||
SpecialToolGroup.rag_tool: RAGToolRuntime,
|
||||
}
|
||||
|
||||
|
||||
def get_all_api_routes(
|
||||
external_apis: dict[Api, ExternalApiSpec] | None = None,
|
||||
) -> dict[Api, list[tuple[Route, WebMethod]]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map(external_apis)
|
||||
toolgroup_protocols = toolgroup_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
routes = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
# HACK ALERT
|
||||
if api == Api.tool_runtime:
|
||||
for tool_group in SpecialToolGroup:
|
||||
sub_protocol = toolgroup_protocols[tool_group]
|
||||
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
|
||||
for name, method in sub_protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
continue
|
||||
protocol_methods.append((f"{tool_group.value}.{name}", method))
|
||||
|
||||
for name, method in protocol_methods:
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
|
|
|
|||
|
|
@ -526,8 +526,8 @@ def extract_path_params(route: str) -> list[str]:
|
|||
|
||||
def remove_disabled_providers(obj):
|
||||
if isinstance(obj, dict):
|
||||
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||
# Filter out items where provider_id is explicitly disabled or empty
|
||||
if "provider_id" in obj and obj["provider_id"] in ("__disabled__", "", None):
|
||||
return None
|
||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||
|
|
@ -78,7 +78,6 @@ class LlamaStack(
|
|||
Inspect,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
Conversations,
|
||||
|
|
|
|||
|
|
@ -163,47 +163,6 @@ class MetricEvent(EventCommon):
|
|||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
"""A metric value included in API responses.
|
||||
:param metric: The name of the metric
|
||||
:param value: The numeric value of the metric
|
||||
:param unit: (Optional) The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
metric: str
|
||||
value: int | float
|
||||
unit: str | None = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be included with the response
|
||||
# To do this, we will need to augment all response types with a metrics field.
|
||||
# We have hit a blocker from stainless SDK that prevents us from doing this.
|
||||
# The blocker is that if we were to augment the response types that have a data field
|
||||
# in them like so
|
||||
# class ListModelsResponse(BaseModel):
|
||||
# metrics: Optional[List[MetricEvent]] = None
|
||||
# data: List[Models]
|
||||
# ...
|
||||
# The client SDK will need to access the data by using a .data field, which is not
|
||||
# ergonomic. Stainless SDK does support unwrapping the response type, but it
|
||||
# requires that the response type to only have a single field.
|
||||
|
||||
# We will need a way in the client SDK to signal that the metrics are needed
|
||||
# and if they are needed, the client SDK has to return the full response type
|
||||
# without unwrapping it.
|
||||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
"""Mixin class for API responses that can include metrics.
|
||||
:param metrics: (Optional) List of metrics associated with the API response
|
||||
"""
|
||||
|
||||
metrics: list[MetricInResponse] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
"""The type of structured log event payload.
|
||||
|
|
@ -427,6 +386,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
|||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
"histograms": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
|
@ -540,6 +500,16 @@ class Telemetry:
|
|||
)
|
||||
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
||||
|
||||
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["histograms"]:
|
||||
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Histogram for {name}",
|
||||
)
|
||||
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
|
|
@ -571,7 +541,16 @@ class Telemetry:
|
|||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
|
||||
# Use histograms for token-related metrics (per-request measurements)
|
||||
# Use counters for other cumulative metrics
|
||||
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
|
||||
|
||||
if event.metric in token_metrics:
|
||||
# Token metrics are per-request measurements, use histogram
|
||||
histogram = self._get_or_create_histogram(event.metric, event.unit)
|
||||
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, float):
|
||||
|
|
|
|||
|
|
@ -129,6 +129,15 @@ def trace_protocol[T: type[Any]](cls: T) -> T:
|
|||
else:
|
||||
return sync_wrapper
|
||||
|
||||
# Wrap methods on the class itself (for classes applied at runtime)
|
||||
# Skip if already wrapped (indicated by __wrapped__ attribute)
|
||||
for name, method in vars(cls).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
if not hasattr(method, "__wrapped__"):
|
||||
wrapped = trace_method(method)
|
||||
setattr(cls, name, wrapped) # noqa: B010
|
||||
|
||||
# Also set up __init_subclass__ for future subclasses
|
||||
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
|
||||
|
||||
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.12-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
/usr/local/bin/pip3 install -r requirements.txt
|
||||
EXPOSE 8501
|
||||
|
||||
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
# (Experimental) LLama Stack UI
|
||||
|
||||
## Docker Setup
|
||||
|
||||
:warning: This is a work in progress.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
|
||||
|
||||
```
|
||||
llama stack list-deps together | xargs -L1 uv pip install
|
||||
|
||||
llama stack run together
|
||||
```
|
||||
|
||||
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
|
||||
|
||||
```bash
|
||||
llama-stack-client datasets register \
|
||||
--dataset-id "mmlu" \
|
||||
--provider-id "huggingface" \
|
||||
--url "https://huggingface.co/datasets/llamastack/evals" \
|
||||
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
|
||||
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
|
||||
```
|
||||
|
||||
```bash
|
||||
llama-stack-client benchmarks register \
|
||||
--eval-task-id meta-reference-mmlu \
|
||||
--provider-id meta-reference \
|
||||
--dataset-id mmlu \
|
||||
--scoring-functions basic::regex_parser_multiple_choice_answer
|
||||
```
|
||||
|
||||
3. Start Streamlit UI
|
||||
|
||||
```bash
|
||||
uv run --with ".[ui]" streamlit run llama_stack.core/ui/app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Environment Variable | Description | Default Value |
|
||||
|----------------------------|------------------------------------|---------------------------|
|
||||
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
|
||||
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
|
||||
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
|
||||
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
|
||||
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def main():
|
||||
# Evaluation pages
|
||||
application_evaluation_page = st.Page(
|
||||
"page/evaluations/app_eval.py",
|
||||
title="Evaluations (Scoring)",
|
||||
icon="📊",
|
||||
default=False,
|
||||
)
|
||||
native_evaluation_page = st.Page(
|
||||
"page/evaluations/native_eval.py",
|
||||
title="Evaluations (Generation + Scoring)",
|
||||
icon="📊",
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Playground pages
|
||||
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
|
||||
|
||||
# Distribution pages
|
||||
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
|
||||
provider_page = st.Page(
|
||||
"page/distribution/providers.py",
|
||||
title="API Providers",
|
||||
icon="🔍",
|
||||
default=False,
|
||||
)
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Playground": [
|
||||
chat_page,
|
||||
rag_page,
|
||||
tool_page,
|
||||
application_evaluation_page,
|
||||
native_evaluation_page,
|
||||
],
|
||||
"Inspect": [provider_page, resources_page],
|
||||
},
|
||||
expanded=False,
|
||||
)
|
||||
pg.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
|
||||
class LlamaStackApi:
|
||||
def __init__(self):
|
||||
self.client = LlamaStackClient(
|
||||
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
|
||||
provider_data={
|
||||
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
|
||||
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
|
||||
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
|
||||
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
|
||||
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
|
||||
"""Run scoring on a single row"""
|
||||
if not scoring_params:
|
||||
scoring_params = dict.fromkeys(scoring_function_ids)
|
||||
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
|
||||
|
||||
|
||||
llama_stack_api = LlamaStackApi()
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def process_dataset(file):
|
||||
if file is None:
|
||||
return "No file uploaded", None
|
||||
|
||||
try:
|
||||
# Determine file type and read accordingly
|
||||
file_ext = os.path.splitext(file.name)[1].lower()
|
||||
if file_ext == ".csv":
|
||||
df = pd.read_csv(file)
|
||||
elif file_ext in [".xlsx", ".xls"]:
|
||||
df = pd.read_excel(file)
|
||||
else:
|
||||
return "Unsupported file format. Please upload a CSV or Excel file.", None
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error processing file: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def data_url_from_file(file) -> str:
|
||||
file_content = file.getvalue()
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type = file.type
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return data_url
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
st.header("Datasets")
|
||||
|
||||
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
|
||||
if len(datasets_info) > 0:
|
||||
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
|
||||
st.json(datasets_info[selected_dataset], expanded=True)
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def benchmarks():
|
||||
# Benchmarks Section
|
||||
st.header("Benchmarks")
|
||||
|
||||
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
|
||||
|
||||
if len(benchmarks_info) > 0:
|
||||
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
|
||||
st.json(benchmarks_info[selected_benchmark], expanded=True)
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
# Models Section
|
||||
st.header("Models")
|
||||
models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}
|
||||
|
||||
selected_model = st.selectbox("Select a model", list(models_info.keys()))
|
||||
st.json(models_info[selected_model])
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
st.header("🔍 API Providers")
|
||||
apis_providers_lst = llama_stack_api.client.providers.list()
|
||||
api_to_providers = {}
|
||||
for api_provider in apis_providers_lst:
|
||||
if api_provider.api in api_to_providers:
|
||||
api_to_providers[api_provider.api].append(api_provider)
|
||||
else:
|
||||
api_to_providers[api_provider.api] = [api_provider]
|
||||
|
||||
for api in api_to_providers.keys():
|
||||
st.markdown(f"###### {api}")
|
||||
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
|
||||
|
||||
|
||||
providers()
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
from llama_stack.core.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.core.ui.page.distribution.models import models
|
||||
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.core.ui.page.distribution.shields import shields
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
"Models",
|
||||
"Shields",
|
||||
"Scoring Functions",
|
||||
"Datasets",
|
||||
"Benchmarks",
|
||||
]
|
||||
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
|
||||
selected_resource = option_menu(
|
||||
None,
|
||||
options,
|
||||
icons=icons,
|
||||
orientation="horizontal",
|
||||
styles={
|
||||
"nav-link": {
|
||||
"font-size": "12px",
|
||||
},
|
||||
},
|
||||
)
|
||||
if selected_resource == "Benchmarks":
|
||||
benchmarks()
|
||||
elif selected_resource == "Datasets":
|
||||
datasets()
|
||||
elif selected_resource == "Models":
|
||||
models()
|
||||
elif selected_resource == "Scoring Functions":
|
||||
scoring_functions()
|
||||
elif selected_resource == "Shields":
|
||||
shields()
|
||||
|
||||
|
||||
resources_page()
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
st.header("Scoring Functions")
|
||||
|
||||
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
|
||||
|
||||
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
|
||||
st.json(scoring_functions_info[selected_scoring_function], expanded=True)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def shields():
|
||||
# Shields Section
|
||||
st.header("Shields")
|
||||
|
||||
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
|
||||
|
||||
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
|
||||
st.json(shields_info[selected_shield])
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.utils import process_dataset
|
||||
|
||||
|
||||
def application_evaluation_page():
|
||||
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Scoring)")
|
||||
|
||||
# File uploader
|
||||
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
|
||||
|
||||
if uploaded_file is None:
|
||||
st.error("No file uploaded")
|
||||
return
|
||||
|
||||
# Process uploaded file
|
||||
df = process_dataset(uploaded_file)
|
||||
if df is None:
|
||||
st.error("Error processing file")
|
||||
return
|
||||
|
||||
# Display dataset information
|
||||
st.success("Dataset loaded successfully!")
|
||||
|
||||
# Display dataframe preview
|
||||
st.subheader("Dataset Preview")
|
||||
st.dataframe(df)
|
||||
|
||||
# Select Scoring Functions to Run Evaluation On
|
||||
st.subheader("Select Scoring Functions")
|
||||
scoring_functions = llama_stack_api.client.scoring_functions.list()
|
||||
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
|
||||
scoring_functions_names = list(scoring_functions.keys())
|
||||
selected_scoring_functions = st.multiselect(
|
||||
"Choose one or more scoring functions",
|
||||
options=scoring_functions_names,
|
||||
help="Choose one or more scoring functions.",
|
||||
)
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [m.identifier for m in available_models]
|
||||
|
||||
scoring_params = {}
|
||||
if selected_scoring_functions:
|
||||
st.write("Selected:")
|
||||
for scoring_fn_id in selected_scoring_functions:
|
||||
scoring_fn = scoring_functions[scoring_fn_id]
|
||||
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
|
||||
new_params = None
|
||||
if scoring_fn.params:
|
||||
new_params = {}
|
||||
for param_name, param_value in scoring_fn.params.to_dict().items():
|
||||
if param_name == "type":
|
||||
new_params[param_name] = param_value
|
||||
continue
|
||||
|
||||
if param_name == "judge_model":
|
||||
value = st.selectbox(
|
||||
f"Select **{param_name}** for {scoring_fn_id}",
|
||||
options=available_models,
|
||||
index=0,
|
||||
key=f"{scoring_fn_id}_{param_name}",
|
||||
)
|
||||
new_params[param_name] = value
|
||||
else:
|
||||
value = st.text_area(
|
||||
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
|
||||
value=json.dumps(param_value, indent=2),
|
||||
height=80,
|
||||
)
|
||||
try:
|
||||
new_params[param_name] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
|
||||
|
||||
st.json(new_params)
|
||||
scoring_params[scoring_fn_id] = new_params
|
||||
|
||||
# Add run evaluation button & slider
|
||||
total_rows = len(df)
|
||||
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
|
||||
|
||||
if st.button("Run Evaluation"):
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = df.to_dict(orient="records")
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
# Create separate containers for progress text and results
|
||||
progress_text_container = st.empty()
|
||||
results_container = st.empty()
|
||||
output_res = {}
|
||||
for i, r in enumerate(rows):
|
||||
# Update progress
|
||||
progress = i / len(rows)
|
||||
progress_bar.progress(progress, text=progress_text)
|
||||
|
||||
# Run evaluation for current row
|
||||
score_res = llama_stack_api.run_scoring(
|
||||
r,
|
||||
scoring_function_ids=selected_scoring_functions,
|
||||
scoring_params=scoring_params,
|
||||
)
|
||||
|
||||
for k in r.keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(r[k])
|
||||
|
||||
for fn_id in selected_scoring_functions:
|
||||
if fn_id not in output_res:
|
||||
output_res[fn_id] = []
|
||||
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
|
||||
|
||||
# Display current row results using separate containers
|
||||
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
|
||||
results_container.json(
|
||||
score_res.to_json(),
|
||||
expanded=2,
|
||||
)
|
||||
|
||||
progress_bar.progress(1.0, text="Evaluation complete!")
|
||||
|
||||
# Display results in dataframe
|
||||
if output_res:
|
||||
output_df = pd.DataFrame(output_res)
|
||||
st.subheader("Evaluation Results")
|
||||
st.dataframe(output_df)
|
||||
|
||||
|
||||
application_evaluation_page()
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_benchmark_1():
|
||||
# Select Benchmarks
|
||||
st.subheader("1. Choose An Eval Task")
|
||||
benchmarks = llama_stack_api.client.benchmarks.list()
|
||||
benchmarks = {et.identifier: et for et in benchmarks}
|
||||
benchmarks_names = list(benchmarks.keys())
|
||||
selected_benchmark = st.selectbox(
|
||||
"Choose an eval task.",
|
||||
options=benchmarks_names,
|
||||
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
|
||||
)
|
||||
with st.expander("View Eval Task"):
|
||||
st.json(benchmarks[selected_benchmark], expanded=True)
|
||||
|
||||
st.session_state["selected_benchmark"] = selected_benchmark
|
||||
st.session_state["benchmarks"] = benchmarks
|
||||
if st.button("Confirm", key="confirm_1"):
|
||||
st.session_state["selected_benchmark_1_next"] = True
|
||||
|
||||
|
||||
def define_eval_candidate_2():
|
||||
if not st.session_state.get("selected_benchmark_1_next", None):
|
||||
return
|
||||
|
||||
st.subheader("2. Define Eval Candidate")
|
||||
st.info(
|
||||
"""
|
||||
Define the configurations for the evaluation candidate model or agent used for generation.
|
||||
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
|
||||
"""
|
||||
)
|
||||
with st.expander("Define Eval Candidate", expanded=True):
|
||||
# Define Eval Candidate
|
||||
candidate_type = st.radio("Candidate Type", ["model", "agent"])
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Sampling Parameters
|
||||
st.markdown("##### Sampling Parameters")
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
)
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
)
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
)
|
||||
repetition_penalty = st.slider(
|
||||
"Repetition Penalty",
|
||||
min_value=1.0,
|
||||
max_value=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
if candidate_type == "model":
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
eval_candidate = {
|
||||
"type": "model",
|
||||
"model": selected_model,
|
||||
"sampling_params": {
|
||||
"strategy": strategy,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
}
|
||||
elif candidate_type == "agent":
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful AI assistant.",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
)
|
||||
tools_json = st.text_area(
|
||||
"Tools Configuration (JSON)",
|
||||
value=json.dumps(
|
||||
[
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": "ENTER_BRAVE_API_KEY_HERE",
|
||||
}
|
||||
]
|
||||
),
|
||||
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
|
||||
height=200,
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools_json)
|
||||
except json.JSONDecodeError:
|
||||
st.error("Invalid JSON format for tools configuration")
|
||||
tools = []
|
||||
eval_candidate = {
|
||||
"type": "agent",
|
||||
"config": {
|
||||
"model": selected_model,
|
||||
"instructions": system_prompt,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"tool_prompt_format": "json",
|
||||
"input_shields": [],
|
||||
"output_shields": [],
|
||||
"enable_session_persistence": False,
|
||||
},
|
||||
}
|
||||
st.session_state["eval_candidate"] = eval_candidate
|
||||
|
||||
if st.button("Confirm", key="confirm_2"):
|
||||
st.session_state["selected_eval_candidate_2_next"] = True
|
||||
|
||||
|
||||
def run_evaluation_3():
|
||||
if not st.session_state.get("selected_eval_candidate_2_next", None):
|
||||
return
|
||||
|
||||
st.subheader("3. Run Evaluation")
|
||||
# Add info box to explain configurations being used
|
||||
st.info(
|
||||
"""
|
||||
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
|
||||
"""
|
||||
)
|
||||
selected_benchmark = st.session_state["selected_benchmark"]
|
||||
benchmarks = st.session_state["benchmarks"]
|
||||
eval_candidate = st.session_state["eval_candidate"]
|
||||
|
||||
dataset_id = benchmarks[selected_benchmark].dataset_id
|
||||
rows = llama_stack_api.client.datasets.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
total_rows = len(rows.data)
|
||||
# Add number of examples control
|
||||
num_rows = st.number_input(
|
||||
"Number of Examples to Evaluate",
|
||||
min_value=1,
|
||||
max_value=total_rows,
|
||||
value=5,
|
||||
help="Number of examples from the dataset to evaluate. ",
|
||||
)
|
||||
|
||||
benchmark_config = {
|
||||
"type": "benchmark",
|
||||
"eval_candidate": eval_candidate,
|
||||
"scoring_params": {},
|
||||
}
|
||||
|
||||
with st.expander("View Evaluation Task", expanded=True):
|
||||
st.json(benchmarks[selected_benchmark], expanded=True)
|
||||
with st.expander("View Evaluation Task Configuration", expanded=True):
|
||||
st.json(benchmark_config, expanded=True)
|
||||
|
||||
# Add run button and handle evaluation
|
||||
if st.button("Run Evaluation"):
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = rows.data
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
# Create separate containers for progress text and results
|
||||
progress_text_container = st.empty()
|
||||
results_container = st.empty()
|
||||
output_res = {}
|
||||
for i, r in enumerate(rows):
|
||||
# Update progress
|
||||
progress = i / len(rows)
|
||||
progress_bar.progress(progress, text=progress_text)
|
||||
# Run evaluation for current row
|
||||
eval_res = llama_stack_api.client.eval.evaluate_rows(
|
||||
benchmark_id=selected_benchmark,
|
||||
input_rows=[r],
|
||||
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
for k in r.keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(r[k])
|
||||
|
||||
for k in eval_res.generations[0].keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(eval_res.generations[0][k])
|
||||
|
||||
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
|
||||
if scoring_fn not in output_res:
|
||||
output_res[scoring_fn] = []
|
||||
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
||||
|
||||
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
|
||||
results_container.json(eval_res, expanded=2)
|
||||
|
||||
progress_bar.progress(1.0, text="Evaluation complete!")
|
||||
# Display results in dataframe
|
||||
if output_res:
|
||||
output_df = pd.DataFrame(output_res)
|
||||
st.subheader("Evaluation Results")
|
||||
st.dataframe(output_df)
|
||||
|
||||
|
||||
def native_evaluation_page():
|
||||
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Generation + Scoring)")
|
||||
|
||||
select_benchmark_1()
|
||||
define_eval_candidate_2()
|
||||
run_evaluation_3()
|
||||
|
||||
|
||||
native_evaluation_page()
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
# Sidebar configurations
|
||||
with st.sidebar:
|
||||
st.header("Configuration")
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [
|
||||
model.id
|
||||
for model in available_models
|
||||
if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
|
||||
]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
index=0,
|
||||
)
|
||||
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
)
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
)
|
||||
|
||||
repetition_penalty = st.slider(
|
||||
"Repetition Penalty",
|
||||
min_value=1.0,
|
||||
max_value=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
|
||||
stream = st.checkbox("Stream", value=True)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful AI assistant.",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
|
||||
|
||||
# Main chat interface
|
||||
st.title("🦙 Chat")
|
||||
|
||||
|
||||
# Initialize chat history
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat messages
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Example: What is Llama Stack?"):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
model_id=selected_model,
|
||||
stream=stream,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
)
|
||||
|
||||
if stream:
|
||||
for chunk in response:
|
||||
if chunk.event.event_type == "progress":
|
||||
full_response += chunk.event.delta.text
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
else:
|
||||
full_response = response.completion_message.content
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
|
|
@ -1,352 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import enum
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
from llama_stack_client.lib.agents.react.agent import ReActAgent
|
||||
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
class AgentType(enum.Enum):
|
||||
REGULAR = "Regular"
|
||||
REACT = "ReAct"
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
client = llama_stack_api.client
|
||||
models = client.models.list()
|
||||
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
|
||||
|
||||
tool_groups = client.toolgroups.list()
|
||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
selected_vector_stores = []
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
with st.sidebar:
|
||||
st.title("Configuration")
|
||||
st.subheader("Model")
|
||||
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
|
||||
|
||||
st.subheader("Available ToolGroups")
|
||||
|
||||
toolgroup_selection = st.pills(
|
||||
label="Built-in tools",
|
||||
options=builtin_tools_list,
|
||||
selection_mode="multi",
|
||||
on_change=reset_agent,
|
||||
format_func=lambda tool: "".join(tool.split("::")[1:]),
|
||||
help="List of built-in tools from your llama stack server.",
|
||||
)
|
||||
|
||||
if "builtin::rag" in toolgroup_selection:
|
||||
vector_stores = llama_stack_api.client.vector_stores.list() or []
|
||||
if not vector_stores:
|
||||
st.info("No vector databases available for selection.")
|
||||
vector_stores = [vector_store.identifier for vector_store in vector_stores]
|
||||
selected_vector_stores = st.multiselect(
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_stores,
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
mcp_selection = st.pills(
|
||||
label="MCP Servers",
|
||||
options=mcp_tools_list,
|
||||
selection_mode="multi",
|
||||
on_change=reset_agent,
|
||||
format_func=lambda tool: "".join(tool.split("::")[1:]),
|
||||
help="List of MCP servers registered to your llama stack server.",
|
||||
)
|
||||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
||||
grouped_tools = {}
|
||||
total_tools = 0
|
||||
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
||||
for group_id, tools in grouped_tools.items():
|
||||
with st.expander(f"🔧 Tools from `{group_id}`"):
|
||||
for idx, tool in enumerate(tools, start=1):
|
||||
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
|
||||
|
||||
st.subheader("Agent Configurations")
|
||||
st.subheader("Agent Type")
|
||||
agent_type = st.radio(
|
||||
label="Select Agent Type",
|
||||
options=["Regular", "ReAct"],
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
if agent_type == "ReAct":
|
||||
agent_type = AgentType.REACT
|
||||
else:
|
||||
agent_type = AgentType.REGULAR
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=64,
|
||||
help="The maximum number of tokens to generate",
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
for i, tool_name in enumerate(toolgroup_selection):
|
||||
if tool_name == "builtin::rag":
|
||||
tool_dict = dict(
|
||||
name="builtin::rag",
|
||||
args={
|
||||
"vector_store_ids": list(selected_vector_stores),
|
||||
},
|
||||
)
|
||||
toolgroup_selection[i] = tool_dict
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
|
||||
return ReActAgent(
|
||||
client=client,
|
||||
model=model,
|
||||
tools=toolgroup_selection,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": ReActOutput.model_json_schema(),
|
||||
},
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
else:
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
st.session_state.agent_type = agent_type
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
||||
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
|
||||
if prompt := st.chat_input(placeholder=""):
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
if st.session_state.get("agent_type") == AgentType.REACT:
|
||||
return _handle_react_response(turn_response)
|
||||
else:
|
||||
return _handle_regular_response(turn_response)
|
||||
|
||||
def _handle_react_response(turn_response):
|
||||
current_step_content = ""
|
||||
final_answer = None
|
||||
tool_results = []
|
||||
|
||||
for response in turn_response:
|
||||
if not hasattr(response.event, "payload"):
|
||||
yield (
|
||||
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
|
||||
"The response received is missing an expected `payload` attribute.\n"
|
||||
"This could indicate a malformed response or an internal issue within the server.\n\n"
|
||||
f"Error details: {response}"
|
||||
)
|
||||
return
|
||||
|
||||
payload = response.event.payload
|
||||
|
||||
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
|
||||
current_step_content += payload.delta.text
|
||||
continue
|
||||
|
||||
if payload.event_type == "step_complete":
|
||||
step_details = payload.step_details
|
||||
|
||||
if step_details.step_type == "inference":
|
||||
yield from _process_inference_step(current_step_content, tool_results, final_answer)
|
||||
current_step_content = ""
|
||||
elif step_details.step_type == "tool_execution":
|
||||
tool_results = _process_tool_execution(step_details, tool_results)
|
||||
current_step_content = ""
|
||||
else:
|
||||
current_step_content = ""
|
||||
|
||||
if not final_answer and tool_results:
|
||||
yield from _format_tool_results_summary(tool_results)
|
||||
|
||||
def _process_inference_step(current_step_content, tool_results, final_answer):
|
||||
try:
|
||||
react_output_data = json.loads(current_step_content)
|
||||
thought = react_output_data.get("thought")
|
||||
action = react_output_data.get("action")
|
||||
answer = react_output_data.get("answer")
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
final_answer = answer
|
||||
|
||||
if thought:
|
||||
with st.expander("🤔 Thinking...", expanded=False):
|
||||
st.markdown(f":grey[__{thought}__]")
|
||||
|
||||
if action and isinstance(action, dict):
|
||||
tool_name = action.get("tool_name")
|
||||
tool_params = action.get("tool_params")
|
||||
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
|
||||
st.json(tool_params)
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
yield f"\n\n✅ **Final Answer:**\n{answer}"
|
||||
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
|
||||
except Exception as e:
|
||||
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
|
||||
|
||||
return final_answer
|
||||
|
||||
def _process_tool_execution(step_details, tool_results):
|
||||
try:
|
||||
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
|
||||
for tool_response in step_details.tool_responses:
|
||||
tool_name = tool_response.tool_name
|
||||
content = tool_response.content
|
||||
tool_results.append((tool_name, content))
|
||||
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
st.json(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
st.code(content, language=None)
|
||||
else:
|
||||
with st.expander("⚙️ Observation", expanded=False):
|
||||
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
|
||||
except Exception as e:
|
||||
with st.expander("⚙️ Error in Tool Execution", expanded=False):
|
||||
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
|
||||
|
||||
return tool_results
|
||||
|
||||
def _format_tool_results_summary(tool_results):
|
||||
yield "\n\n**Here's what I found:**\n"
|
||||
for tool_name, content in tool_results:
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
|
||||
if tool_name == "web_search" and "top_k" in parsed_content:
|
||||
yield from _format_web_search_results(parsed_content)
|
||||
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
|
||||
yield from _format_results_list(parsed_content["results"])
|
||||
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
|
||||
yield from _format_dict_results(parsed_content)
|
||||
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
|
||||
yield from _format_list_results(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
|
||||
except (TypeError, AttributeError, KeyError, IndexError) as e:
|
||||
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
|
||||
|
||||
def _format_web_search_results(parsed_content):
|
||||
for i, result in enumerate(parsed_content["top_k"], 1):
|
||||
if i <= 3:
|
||||
title = result.get("title", "Untitled")
|
||||
url = result.get("url", "")
|
||||
content_text = result.get("content", "").strip()
|
||||
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
|
||||
|
||||
def _format_results_list(results):
|
||||
for i, result in enumerate(results, 1):
|
||||
if i <= 3:
|
||||
if isinstance(result, dict):
|
||||
name = result.get("name", result.get("title", "Result " + str(i)))
|
||||
description = result.get("description", result.get("content", result.get("summary", "")))
|
||||
yield f"\n- **{name}**\n {description}\n"
|
||||
else:
|
||||
yield f"\n- {result}\n"
|
||||
|
||||
def _format_dict_results(parsed_content):
|
||||
yield "\n```\n"
|
||||
for key, value in list(parsed_content.items())[:5]:
|
||||
if isinstance(value, str) and len(value) < 100:
|
||||
yield f"{key}: {value}\n"
|
||||
else:
|
||||
yield f"{key}: [Complex data]\n"
|
||||
yield "```\n"
|
||||
|
||||
def _format_list_results(parsed_content):
|
||||
yield "\n"
|
||||
for _, item in enumerate(parsed_content[:3], 1):
|
||||
if isinstance(item, str):
|
||||
yield f"- {item}\n"
|
||||
elif isinstance(item, dict) and "text" in item:
|
||||
yield f"- {item['text']}\n"
|
||||
elif isinstance(item, dict) and len(item) > 0:
|
||||
first_value = next(iter(item.values()))
|
||||
if isinstance(first_value, str) and len(first_value) < 100:
|
||||
yield f"- {first_value}\n"
|
||||
|
||||
def _handle_regular_response(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
if response.event.payload.event_type == "step_progress":
|
||||
if hasattr(response.event.payload.delta, "text"):
|
||||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
if response.event.payload.step_details.tool_calls:
|
||||
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
|
||||
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
|
||||
else:
|
||||
yield "No tool_calls present in step_details"
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response_content = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
tool_chat_page()
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
llama-stack>=0.2.1
|
||||
llama-stack-client>=0.2.1
|
||||
pandas
|
||||
streamlit
|
||||
streamlit-option-menu
|
||||
|
|
@ -52,7 +52,17 @@ def resolve_config_or_distro(
|
|||
logger.debug(f"Using distribution: {distro_config}")
|
||||
return distro_config
|
||||
|
||||
# Strategy 3: Try as built distribution name
|
||||
# Strategy 3: Try as distro config path (if no .yaml extension and contains a slash)
|
||||
# eg: starter::run-with-postgres-store.yaml
|
||||
# Use :: to avoid slash and confusion with a filesystem path
|
||||
if "::" in config_or_distro:
|
||||
distro_name, config_name = config_or_distro.split("::")
|
||||
distro_config = _get_distro_config_path(distro_name, config_name)
|
||||
if distro_config.exists():
|
||||
logger.info(f"Using distribution: {distro_config}")
|
||||
return distro_config
|
||||
|
||||
# Strategy 4: Try as built distribution name
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.debug(f"Using built distribution: {distrib_config}")
|
||||
|
|
@ -63,13 +73,15 @@ def resolve_config_or_distro(
|
|||
logger.debug(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
# Strategy 4: Failed - provide helpful error
|
||||
# Strategy 5: Failed - provide helpful error
|
||||
raise ValueError(_format_resolution_error(config_or_distro, mode))
|
||||
|
||||
|
||||
def _get_distro_config_path(distro_name: str, mode: Mode) -> Path:
|
||||
def _get_distro_config_path(distro_name: str, mode: str) -> Path:
|
||||
"""Get the config file path for a distro."""
|
||||
return DISTRO_DIR / distro_name / f"{mode}.yaml"
|
||||
if not mode.endswith(".yaml"):
|
||||
mode = f"{mode}.yaml"
|
||||
return DISTRO_DIR / distro_name / mode
|
||||
|
||||
|
||||
def _format_resolution_error(config_or_distro: str, mode: Mode) -> str:
|
||||
|
|
|
|||
|
|
@ -84,6 +84,15 @@ def run_command(command: list[str]) -> int:
|
|||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Print stdout and stderr if command failed
|
||||
if result.returncode != 0:
|
||||
log.error(f"Command {' '.join(command)} failed with returncode {result.returncode}")
|
||||
if result.stdout:
|
||||
log.error(f"STDOUT: {result.stdout}")
|
||||
if result.stderr:
|
||||
log.error(f"STDERR: {result.stderr}")
|
||||
|
||||
return result.returncode
|
||||
except subprocess.SubprocessError as e:
|
||||
log.error(f"Subprocess error: {e}")
|
||||
|
|
|
|||
|
|
@ -56,4 +56,5 @@ image_type: venv
|
|||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- asyncpg
|
||||
- psycopg2-binary
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,293 @@
|
|||
version: 2
|
||||
image_name: ci-tests
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
|
||||
provider_type: remote::vertexai
|
||||
config:
|
||||
project: ${env.VERTEX_AI_PROJECT:=}
|
||||
location: ${env.VERTEX_AI_LOCATION:=us-central1}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||
persistence:
|
||||
namespace: vector_io::sqlite_vec
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
|
||||
persistence:
|
||||
namespace: vector_io::milvus
|
||||
backend: kv_default
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
persistence:
|
||||
namespace: vector_io::chroma_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:=localhost}
|
||||
port: ${env.PGVECTOR_PORT:=5432}
|
||||
db: ${env.PGVECTOR_DB:=}
|
||||
user: ${env.PGVECTOR_USER:=}
|
||||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
@ -46,6 +46,9 @@ providers:
|
|||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .oci import get_distribution_template # noqa: F401
|
||||
35
src/llama_stack/distributions/oci/build.yaml
Normal file
35
src/llama_stack/distributions/oci/build.yaml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
|
||||
inference with scalable cloud services
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::oci
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
140
src/llama_stack/distributions/oci/doc_template.md
Normal file
140
src/llama_stack/distributions/oci/doc_template.md
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# OCI Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## Prerequisites
|
||||
### Oracle Cloud Infrastructure Setup
|
||||
|
||||
Before using the OCI Generative AI distribution, ensure you have:
|
||||
|
||||
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
|
||||
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
|
||||
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
|
||||
4. **Authentication**: Configure authentication using either:
|
||||
- **Instance Principal** (recommended for cloud-hosted deployments)
|
||||
- **API Key** (for on-premises or development environments)
|
||||
|
||||
### Authentication Methods
|
||||
|
||||
#### Instance Principal Authentication (Recommended)
|
||||
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
|
||||
|
||||
Requirements:
|
||||
- Instance must be running in an Oracle Cloud Infrastructure compartment
|
||||
- Instance must have appropriate IAM policies to access Generative AI services
|
||||
|
||||
#### API Key Authentication
|
||||
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
|
||||
|
||||
### Required IAM Policies
|
||||
|
||||
Ensure your OCI user or instance has the following policy statements:
|
||||
|
||||
```
|
||||
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
```
|
||||
|
||||
## Supported Services
|
||||
|
||||
### Inference: OCI Generative AI
|
||||
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
|
||||
|
||||
- **Chat Completions**: Conversational AI with context awareness
|
||||
- **Text Generation**: Complete prompts and generate text content
|
||||
|
||||
#### Available Models
|
||||
Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models.
|
||||
|
||||
### Safety: Llama Guard
|
||||
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
|
||||
- Content filtering and moderation
|
||||
- Policy compliance checking
|
||||
- Harmful content detection
|
||||
|
||||
### Vector Storage: Multiple Options
|
||||
The distribution supports several vector storage providers:
|
||||
- **FAISS**: Local in-memory vector search
|
||||
- **ChromaDB**: Distributed vector database
|
||||
- **PGVector**: PostgreSQL with vector extensions
|
||||
|
||||
### Additional Services
|
||||
- **Dataset I/O**: Local filesystem and Hugging Face integration
|
||||
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
|
||||
- **Evaluation**: Meta reference evaluation framework
|
||||
|
||||
## Running Llama Stack with OCI
|
||||
|
||||
You can run the OCI distribution via Docker or local virtual environment.
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
|
||||
```
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### Using Instance Principal (Recommended for Production)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=instance_principal
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
|
||||
```
|
||||
|
||||
#### Using API Key Authentication (Development)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=config_file
|
||||
export OCI_CONFIG_FILE_PATH=~/.oci/config
|
||||
export OCI_CLI_PROFILE=DEFAULT
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
|
||||
```
|
||||
|
||||
## Regional Endpoints
|
||||
|
||||
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
|
||||
|
||||
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
|
||||
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
|
||||
3. **Permission Denied**: Check compartment permissions and Generative AI service access
|
||||
4. **Region Unavailable**: Verify the specified region supports Generative AI services
|
||||
|
||||
### Getting Help
|
||||
|
||||
For additional support:
|
||||
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
|
||||
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)
|
||||
108
src/llama_stack/distributions/oci/oci.py
Normal file
108
src/llama_stack/distributions/oci/oci.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.remote.inference.oci.config import OCIConfig
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type="remote::oci")],
|
||||
"vector_io": [
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="oci",
|
||||
provider_type="remote::oci",
|
||||
config=OCIConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
vector_io_provider = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"OCI_AUTH_TYPE": (
|
||||
"instance_principal",
|
||||
"OCI authentication type (instance_principal or config_file)",
|
||||
),
|
||||
"OCI_REGION": (
|
||||
"",
|
||||
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
|
||||
),
|
||||
"OCI_COMPARTMENT_OCID": (
|
||||
"",
|
||||
"OCI compartment ID for the Generative AI service",
|
||||
),
|
||||
"OCI_CONFIG_FILE_PATH": (
|
||||
"~/.oci/config",
|
||||
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
|
||||
),
|
||||
"OCI_CLI_PROFILE": (
|
||||
"DEFAULT",
|
||||
"OCI CLI profile name to use from config file",
|
||||
),
|
||||
},
|
||||
)
|
||||
136
src/llama_stack/distributions/oci/run.yaml
Normal file
136
src/llama_stack/distributions/oci/run.yaml
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
version: 2
|
||||
image_name: oci
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: oci
|
||||
provider_type: remote::oci
|
||||
config:
|
||||
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
|
||||
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
|
||||
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
|
||||
oci_region: ${env.OCI_REGION:=us-ashburn-1}
|
||||
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .postgres_demo import get_distribution_template # noqa: F401
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Quick start template for running Llama Stack with several popular providers
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: remote::chromadb
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- asyncpg
|
||||
- psycopg2-binary
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.distributions.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import SentenceTransformersInferenceConfig
|
||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
inference_providers = [
|
||||
Provider(
|
||||
provider_id="vllm-inference",
|
||||
provider_type="remote::vllm",
|
||||
config=VLLMInferenceAdapterConfig.sample_run_config(
|
||||
url="${env.VLLM_URL:=http://localhost:8000/v1}",
|
||||
),
|
||||
),
|
||||
]
|
||||
providers = {
|
||||
"inference": [
|
||||
BuildProvider(provider_type="remote::vllm"),
|
||||
BuildProvider(provider_type="inline::sentence-transformers"),
|
||||
],
|
||||
"vector_io": [BuildProvider(provider_type="remote::chromadb")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
name = "postgres-demo"
|
||||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
]
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = [
|
||||
ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="vllm-inference",
|
||||
)
|
||||
]
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
embedding_model = ModelInput(
|
||||
model_id="nomic-embed-text-v1.5",
|
||||
provider_id=embedding_provider.provider_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Quick start template for running Llama Stack with several popular providers",
|
||||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider={},
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": inference_providers + [embedding_provider],
|
||||
"vector_io": vector_io_providers,
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
storage_backends={
|
||||
"kv_default": PostgresKVStoreConfig.sample_run_config(
|
||||
table_name="llamastack_kvstore",
|
||||
),
|
||||
"sql_default": PostgresSqlStoreConfig.sample_run_config(),
|
||||
},
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
"8321",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
@ -57,4 +57,5 @@ image_type: venv
|
|||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- asyncpg
|
||||
- psycopg2-binary
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,296 @@
|
|||
version: 2
|
||||
image_name: starter-gpu
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
|
||||
provider_type: remote::vertexai
|
||||
config:
|
||||
project: ${env.VERTEX_AI_PROJECT:=}
|
||||
location: ${env.VERTEX_AI_LOCATION:=us-central1}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
|
||||
persistence:
|
||||
namespace: vector_io::sqlite_vec
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
|
||||
persistence:
|
||||
namespace: vector_io::milvus
|
||||
backend: kv_default
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
persistence:
|
||||
namespace: vector_io::chroma_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:=localhost}
|
||||
port: ${env.PGVECTOR_PORT:=5432}
|
||||
db: ${env.PGVECTOR_DB:=}
|
||||
user: ${env.PGVECTOR_USER:=}
|
||||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: huggingface-gpu
|
||||
provider_type: inline::huggingface-gpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
@ -46,6 +46,9 @@ providers:
|
|||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -57,4 +57,5 @@ image_type: venv
|
|||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- asyncpg
|
||||
- psycopg2-binary
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,293 @@
|
|||
version: 2
|
||||
image_name: starter
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
|
||||
provider_type: remote::vertexai
|
||||
config:
|
||||
project: ${env.VERTEX_AI_PROJECT:=}
|
||||
location: ${env.VERTEX_AI_LOCATION:=us-central1}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
persistence:
|
||||
namespace: vector_io::sqlite_vec
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
persistence:
|
||||
namespace: vector_io::milvus
|
||||
backend: kv_default
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
persistence:
|
||||
namespace: vector_io::chroma_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:=localhost}
|
||||
port: ${env.PGVECTOR_PORT:=5432}
|
||||
db: ${env.PGVECTOR_DB:=}
|
||||
user: ${env.PGVECTOR_USER:=}
|
||||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
@ -46,6 +46,9 @@ providers:
|
|||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
|||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
||||
|
|
@ -148,10 +149,11 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::reference"),
|
||||
],
|
||||
}
|
||||
files_config = LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}")
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
config=files_config,
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
|
|
@ -181,6 +183,90 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
|
||||
),
|
||||
]
|
||||
postgres_sql_config = PostgresSqlStoreConfig.sample_run_config()
|
||||
postgres_kv_config = PostgresKVStoreConfig.sample_run_config()
|
||||
default_overrides = {
|
||||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
url="${env.QDRANT_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
|
||||
),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
}
|
||||
|
||||
base_run_settings = RunConfigSettings(
|
||||
provider_overrides=default_overrides,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
vector_stores_config=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="sentence-transformers",
|
||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
),
|
||||
),
|
||||
safety_config=SafetyConfig(
|
||||
default_shield_id="llama-guard",
|
||||
),
|
||||
)
|
||||
|
||||
postgres_run_settings = base_run_settings.model_copy(
|
||||
update={
|
||||
"storage_backends": {
|
||||
"kv_default": postgres_kv_config,
|
||||
"sql_default": postgres_sql_config,
|
||||
}
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
|
@ -189,78 +275,10 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
url="${env.QDRANT_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
|
||||
),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
vector_stores_config=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="sentence-transformers",
|
||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
),
|
||||
),
|
||||
safety_config=SafetyConfig(
|
||||
default_shield_id="llama-guard",
|
||||
),
|
||||
),
|
||||
"run.yaml": base_run_settings,
|
||||
"run-with-postgres-store.yaml": postgres_run_settings,
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
|
|
|
|||
|
|
@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, LLMInput
|
||||
from .model import Transformer
|
||||
|
|
|
|||
|
|
@ -15,13 +15,10 @@ from pathlib import Path
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
|
||||
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import json
|
|||
import re
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
|
||||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
from ..datatypes import RecursiveType
|
||||
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
import textwrap
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
PromptTemplate,
|
||||
PromptTemplateGeneratorBase,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,21 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
|
|
@ -26,19 +14,12 @@ from llama_stack.apis.agents import (
|
|||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Order,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
|
|
@ -46,12 +27,9 @@ from llama_stack.apis.vector_io import VectorIO
|
|||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
|
@ -97,229 +75,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
created_at = datetime.now(UTC)
|
||||
|
||||
agent_info = AgentInfo(
|
||||
**agent_config.model_dump(),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
# Store the agent info
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_info.model_dump_json(),
|
||||
)
|
||||
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_info_json = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
if not agent_info_json:
|
||||
raise ValueError(f"Could not find agent info for {agent_id}")
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
||||
|
||||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at.isoformat(),
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _create_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _continue_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
if turn is None:
|
||||
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||
for step in turn.steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
return Session(
|
||||
session_name=session_info.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns,
|
||||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
await agent.storage.delete_session_turns(session_id)
|
||||
await agent.storage.delete_session(session_id)
|
||||
|
||||
async def delete_agent(self, agent_id: str) -> None:
|
||||
# First get all sessions for this agent
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
|
||||
# Delete all sessions
|
||||
for session in sessions:
|
||||
await self.delete_agents_session(agent_id, session.session_id)
|
||||
|
||||
# Finally delete the agent itself
|
||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
||||
agent_list: list[Agent] = []
|
||||
for agent_key in agent_keys:
|
||||
agent_id = agent_key.split(":")[1]
|
||||
|
||||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
agent_list.append(
|
||||
Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
created_at=agent_info.created_at,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
agent_dicts = [agent.model_dump() for agent in agent_list]
|
||||
return paginate_records(agent_dicts, start_index, limit)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
chat_agent = await self._get_agent_impl(agent_id)
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||
)
|
||||
return agent
|
||||
|
||||
async def list_agent_sessions(
|
||||
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
||||
) -> PaginatedResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
# Convert Session objects to dictionaries
|
||||
session_dicts = [session.model_dump() for session in sessions]
|
||||
return paginate_records(session_dicts, start_index, limit)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
|
|
@ -347,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
result = await self.openai_responses_impl.create_openai_response(
|
||||
|
|
@ -364,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
include,
|
||||
max_infer_iters,
|
||||
guardrails,
|
||||
max_tool_calls,
|
||||
)
|
||||
return result # type: ignore[no-any-return]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,261 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
# TODO: is this used anywhere?
|
||||
vector_store_id: str | None = None
|
||||
started_at: datetime
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResource:
|
||||
"""Concrete implementation of ProtectedResource for session access control."""
|
||||
|
||||
type: str
|
||||
identifier: str
|
||||
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
self.policy = policy
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
user = get_authenticated_user()
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(UTC),
|
||||
owner=user,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
# Only perform access control if we have an authenticated user
|
||||
if user is not None and session_info.identifier is not None:
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=user,
|
||||
)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
# Get current user - if None, skip access control (e.g., in tests)
|
||||
user = get_authenticated_user()
|
||||
if user is None:
|
||||
return True
|
||||
|
||||
# Access control requires identifier and owner to be set
|
||||
if session_info.identifier is None or session_info.owner is None:
|
||||
return True
|
||||
|
||||
# At this point, both identifier and owner are guaranteed to be non-None
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=session_info.owner,
|
||||
)
|
||||
return is_action_allowed(self.policy, Action.READ, resource, user)
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
|
||||
session_info = await self.get_session_if_accessible(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info.vector_store_id = vector_store_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
||||
async def list_sessions(self) -> list[Session]:
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:",
|
||||
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
sessions = []
|
||||
for value in values:
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if "turn_id" in data:
|
||||
continue
|
||||
|
||||
session_info = Session(**data)
|
||||
sessions.append(session_info)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing session info: {e}")
|
||||
continue
|
||||
return sessions
|
||||
|
||||
async def delete_session_turns(self, session_id: str) -> None:
|
||||
"""Delete all turns and their associated data for a session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session whose turns should be deleted.
|
||||
"""
|
||||
turns = await self.get_session_turns(session_id)
|
||||
for turn in turns:
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
|
||||
|
||||
async def delete_session(self, session_id: str) -> None:
|
||||
"""Delete a session and all its associated turns.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist.
|
||||
"""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
|
||||
|
|
@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
|
|||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
|
@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
|
|||
if not conversation.startswith("conv_"):
|
||||
raise InvalidConversationIdError(conversation)
|
||||
|
||||
if max_tool_calls is not None and max_tool_calls < 1:
|
||||
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
conversation=conversation,
|
||||
|
|
@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
|
|||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
guardrail_ids=guardrail_ids,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# These should never be None when called from create_openai_response (which sets defaults)
|
||||
# but we assert here to help mypy understand the types
|
||||
|
|
@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
|
|||
safety_api=self.safety_api,
|
||||
guardrail_ids=guardrail_ids,
|
||||
instructions=instructions,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
|
|||
safety_api,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
prompt: OpenAIResponsePrompt | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
|
@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
|
|||
self.safety_api = safety_api
|
||||
self.guardrail_ids = guardrail_ids or []
|
||||
self.prompt = prompt
|
||||
# System message that is inserted into the model's context
|
||||
self.instructions = instructions
|
||||
# Max number of total calls to built-in tools that can be processed in a response
|
||||
self.max_tool_calls = max_tool_calls
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
|
|
@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
|
|||
self.accumulated_usage: OpenAIResponseUsage | None = None
|
||||
# Track if we've sent a refusal response
|
||||
self.violation_detected = False
|
||||
# system message that is inserted into the model's context
|
||||
self.instructions = instructions
|
||||
# Track total calls made to built-in tools
|
||||
self.accumulated_builtin_tool_calls = 0
|
||||
|
||||
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
|
||||
"""Create a refusal response to replace streaming content."""
|
||||
|
|
@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
|
|||
usage=self.accumulated_usage,
|
||||
instructions=self.instructions,
|
||||
prompt=self.prompt,
|
||||
max_tool_calls=self.max_tool_calls,
|
||||
)
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
|
|
@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
|
|||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
|
||||
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
|
||||
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
|
||||
break
|
||||
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
|
|
@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
|
|||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
# Track number of calls made to built-in and mcp tools
|
||||
self.accumulated_builtin_tool_calls += 1
|
||||
|
||||
# Execute function tool calls (client-side)
|
||||
for tool_call in function_tool_calls:
|
||||
# Find the item_id for this tool call from our tracking dictionary
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.agents import Agents, StepType
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
|
@ -18,13 +18,9 @@ from llama_stack.apis.inference import (
|
|||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
|
|
@ -118,49 +114,6 @@ class MetaReferenceEvalImpl(
|
|||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
generations = []
|
||||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=input_messages,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
|
||||
final_event = turn_response[-1].event.payload
|
||||
|
||||
# check if there's a memory retrieval step and extract the context
|
||||
memory_rag_context = None
|
||||
for step in final_event.turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
for tool_response in step.tool_responses:
|
||||
if tool_response.tool_name == MEMORY_QUERY_TOOL:
|
||||
memory_rag_context = " ".join(x.text for x in tool_response.content)
|
||||
|
||||
agent_generation = {}
|
||||
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
|
||||
if memory_rag_context:
|
||||
agent_generation[ColumnName.context.value] = memory_rag_context
|
||||
|
||||
generations.append(agent_generation)
|
||||
|
||||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -215,9 +168,8 @@ class MetaReferenceEvalImpl(
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
elif candidate.type == "model":
|
||||
# Agent evaluation removed
|
||||
if candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
|
@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
|
|||
from llama_stack.apis.inference import (
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
get_default_tool_prompt_format,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
|
@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
|
|||
return temperature, top_p
|
||||
|
||||
|
||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||
tool_config = request.tool_config
|
||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
||||
return tool_config.tool_prompt_format
|
||||
else:
|
||||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
class LlamaGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -157,55 +146,56 @@ class LlamaGenerator:
|
|||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
request: OpenAIChatCompletionRequestWithExtraBody,
|
||||
raw_messages: list,
|
||||
):
|
||||
"""Generate chat completion using OpenAI request format.
|
||||
|
||||
Args:
|
||||
request: OpenAI chat completion request
|
||||
raw_messages: Pre-converted list of RawMessage objects
|
||||
"""
|
||||
|
||||
# Determine tool prompt format
|
||||
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
|
||||
|
||||
# Prepare sampling params
|
||||
sampling_params = SamplingParams()
|
||||
if request.temperature is not None or request.top_p is not None:
|
||||
sampling_params.strategy = TopPSamplingStrategy(
|
||||
temperature=request.temperature if request.temperature is not None else 1.0,
|
||||
top_p=request.top_p if request.top_p is not None else 1.0,
|
||||
)
|
||||
if request.max_tokens:
|
||||
sampling_params.max_tokens = request.max_tokens
|
||||
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
|
||||
# Get logits processor for response format
|
||||
logits_processor = None
|
||||
if request.response_format:
|
||||
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
|
||||
# Extract the actual schema from OpenAIJSONSchema TypedDict
|
||||
schema_dict = request.response_format.json_schema.get("schema") or {}
|
||||
json_schema_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema,
|
||||
json_schema=schema_dict,
|
||||
)
|
||||
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
|
||||
|
||||
# Generate
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
],
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
logprobs=False,
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAIChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||
JsonCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
|
@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
|
||||
"""Convert OpenAI tool format to ToolDefinition format."""
|
||||
# OpenAI tools have function.name and function.parameters
|
||||
return ToolDefinition(
|
||||
tool_name=tool.function.name,
|
||||
description=tool.function.description or "",
|
||||
parameters=tool.function.parameters or {},
|
||||
)
|
||||
|
||||
|
||||
def _get_tool_choice_prompt(tool_choice, tools) -> str:
|
||||
"""Generate prompt text for tool_choice behavior."""
|
||||
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
|
||||
return ""
|
||||
elif tool_choice == ToolChoice.required or tool_choice == "required":
|
||||
return "You MUST use one of the provided functions/tools to answer the user query."
|
||||
elif tool_choice == ToolChoice.none or tool_choice == "none":
|
||||
return ""
|
||||
else:
|
||||
# Specific tool specified
|
||||
return f"You MUST use the tool `{tool_choice}` to answer the user query."
|
||||
|
||||
|
||||
def _raw_content_as_str(content) -> str:
|
||||
"""Convert RawContent to string for system messages."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, RawTextItem):
|
||||
return content.text
|
||||
elif isinstance(content, list):
|
||||
return "\n".join(_raw_content_as_str(c) for c in content)
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
|
||||
def _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages: list[RawMessage],
|
||||
tools: list,
|
||||
tool_choice,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
|
||||
messages = raw_messages.copy()
|
||||
existing_system_message = None
|
||||
if messages and messages[0].role == "system":
|
||||
existing_system_message = messages.pop(0)
|
||||
|
||||
sys_content = ""
|
||||
|
||||
# Add tool definitions first (if present)
|
||||
if tools:
|
||||
# Convert OpenAI tools to ToolDefinitions
|
||||
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
|
||||
|
||||
# For OpenAI format, all tools are custom (have string names)
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
tool_template = tool_gen.gen(tool_definitions)
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
# Add default system prompt
|
||||
default_gen = SystemDefaultGenerator()
|
||||
default_template = default_gen.gen()
|
||||
sys_content += default_template.render()
|
||||
|
||||
# Add existing system message if present
|
||||
if existing_system_message:
|
||||
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
# Add tool choice prompt if needed
|
||||
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
# Create new system message
|
||||
new_system_message = RawMessage(
|
||||
role="system",
|
||||
content=[RawTextItem(text=sys_content.strip())],
|
||||
)
|
||||
|
||||
return [new_system_message] + messages
|
||||
|
||||
|
||||
def _augment_raw_messages_for_tools_llama_4(
|
||||
raw_messages: list[RawMessage],
|
||||
tools: list,
|
||||
tool_choice,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
|
||||
messages = raw_messages.copy()
|
||||
existing_system_message = None
|
||||
if messages and messages[0].role == "system":
|
||||
existing_system_message = messages.pop(0)
|
||||
|
||||
sys_content = ""
|
||||
|
||||
# Add tool definitions if present
|
||||
if tools:
|
||||
# Convert OpenAI tools to ToolDefinitions
|
||||
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
|
||||
|
||||
# Use python_list format for Llama 4
|
||||
tool_gen = PythonListCustomToolGeneratorLlama4()
|
||||
system_prompt = None
|
||||
if existing_system_message:
|
||||
system_prompt = _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
tool_template = tool_gen.gen(tool_definitions, system_prompt)
|
||||
sys_content = tool_template.render()
|
||||
elif existing_system_message:
|
||||
# No tools, just use existing system message
|
||||
sys_content = _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
# Add tool choice prompt if needed
|
||||
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
if sys_content:
|
||||
new_system_message = RawMessage(
|
||||
role="system",
|
||||
content=[RawTextItem(text=sys_content.strip())],
|
||||
)
|
||||
return [new_system_message] + messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def augment_raw_messages_for_tools(
|
||||
raw_messages: list[RawMessage],
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
llama_model,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions based on model family."""
|
||||
if not params.tools:
|
||||
return raw_messages
|
||||
|
||||
# Determine augmentation strategy based on model family
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
|
||||
):
|
||||
# Llama 3.1 and Llama 3.2 multimodal use JSON format
|
||||
return _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
elif llama_model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# Llama 3.2/3.3/4 use python_list format
|
||||
return _augment_raw_messages_for_tools_llama_4(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
else:
|
||||
# Default to Llama 3.1 style
|
||||
return _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
||||
|
|
@ -136,17 +315,20 @@ class MetaReferenceInferenceImpl(
|
|||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
|
||||
await self.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=[{"role": "user", "content": "Hi how are you?"}],
|
||||
max_tokens=20,
|
||||
params=OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model_id,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
|
||||
max_tokens=20,
|
||||
)
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
"No available model yet, please register your requested model or add your model in the resources first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
|
@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
self.check_model(params)
|
||||
|
||||
# Convert OpenAI messages to RawMessages
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_openai_message_to_raw_message,
|
||||
decode_assistant_message,
|
||||
)
|
||||
|
||||
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
|
||||
|
||||
# Augment messages with tool definitions if tools are present
|
||||
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
|
||||
|
||||
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
|
||||
if isinstance(self.generator, LlamaGenerator):
|
||||
generator = self.generator.chat_completion(params, raw_messages)
|
||||
else:
|
||||
# Model parallel: submit task to process group
|
||||
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
|
||||
|
||||
# Check if streaming is requested
|
||||
if params.stream:
|
||||
return self._stream_chat_completion(generator, params)
|
||||
|
||||
# Non-streaming: collect all generated text
|
||||
generated_text = ""
|
||||
for result_batch in generator:
|
||||
for result in result_batch:
|
||||
if not result.ignore_token and result.source == "output":
|
||||
generated_text += result.text
|
||||
|
||||
# Decode assistant message to extract tool calls and determine stop_reason
|
||||
# Default to end_of_turn if generation completed normally
|
||||
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
|
||||
|
||||
# Convert tool calls to OpenAI format
|
||||
openai_tool_calls = None
|
||||
if decoded_message.tool_calls:
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
)
|
||||
|
||||
openai_tool_calls = [
|
||||
OpenAIChatCompletionToolCall(
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=str(tc.tool_name),
|
||||
arguments=tc.arguments,
|
||||
),
|
||||
)
|
||||
for tc in decoded_message.tool_calls
|
||||
]
|
||||
|
||||
# Determine finish_reason based on whether tool calls are present
|
||||
finish_reason = "tool_calls" if openai_tool_calls else "stop"
|
||||
|
||||
# Extract content from decoded message
|
||||
content = ""
|
||||
if isinstance(decoded_message.content, str):
|
||||
content = decoded_message.content
|
||||
elif isinstance(decoded_message.content, list):
|
||||
for item in decoded_message.content:
|
||||
if isinstance(item, RawTextItem):
|
||||
content += item.text
|
||||
|
||||
# Create OpenAI response
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=response_id,
|
||||
object="chat.completion",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=openai_tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=OpenAIChatCompletionUsage(
|
||||
prompt_tokens=0, # TODO: calculate properly
|
||||
completion_tokens=0, # TODO: calculate properly
|
||||
total_tokens=0, # TODO: calculate properly
|
||||
),
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self,
|
||||
generator,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Stream chat completion chunks as they're generated."""
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
||||
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
generated_text = ""
|
||||
|
||||
# Yield chunks as tokens are generated
|
||||
for result_batch in generator:
|
||||
for result in result_batch:
|
||||
if result.ignore_token or result.source != "output":
|
||||
continue
|
||||
|
||||
generated_text += result.text
|
||||
|
||||
# Yield delta chunk with the new text
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(
|
||||
role="assistant",
|
||||
content=result.text,
|
||||
),
|
||||
finish_reason="",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
# After generation completes, decode the full message to extract tool calls
|
||||
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
|
||||
|
||||
# If tool calls are present, yield a final chunk with tool_calls
|
||||
if decoded_message.tool_calls:
|
||||
openai_tool_calls = [
|
||||
OpenAIChatCompletionToolCall(
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=str(tc.tool_name),
|
||||
arguments=tc.arguments,
|
||||
),
|
||||
)
|
||||
for tc in decoded_message.tool_calls
|
||||
]
|
||||
|
||||
# Yield chunk with tool_calls
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(
|
||||
role="assistant",
|
||||
tool_calls=openai_tool_calls,
|
||||
),
|
||||
finish_reason="",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
finish_reason = "tool_calls"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# Yield final chunk with finish_reason
|
||||
final_chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(),
|
||||
finish_reason=finish_reason,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield final_chunk
|
||||
|
|
|
|||
|
|
@ -4,17 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from copy import deepcopy
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
|
@ -23,12 +18,14 @@ class ModelRunner:
|
|||
def __init__(self, llama):
|
||||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, task: Any):
|
||||
if task[0] == "chat_completion":
|
||||
return self.llama.chat_completion(task[1])
|
||||
task_type = task[0]
|
||||
if task_type == "chat_completion":
|
||||
# task[1] is [params, raw_messages]
|
||||
params, raw_messages = task[1]
|
||||
return self.llama.chat_completion(params, raw_messages)
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {task[0]}")
|
||||
raise ValueError(f"Unexpected task type {task_type}")
|
||||
|
||||
|
||||
def init_model_cb(
|
||||
|
|
@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.group.stop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("completion", req_obj))
|
||||
yield from gen
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
||||
yield from gen
|
||||
|
|
|
|||
|
|
@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
|
@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
|
|||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: tuple[
|
||||
str,
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
]
|
||||
task: tuple[str, list]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
|
@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
|
|||
|
||||
def run_inference(
|
||||
self,
|
||||
req: tuple[
|
||||
str,
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
],
|
||||
req: tuple[str, list],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue