Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 10:10:36 +00:00)

Commit 8df9340dd3 (parent: 357be98279)

    wiprouters

    Signed-off-by: Sébastien Han <seb@redhat.com>

155 changed files with 61817 additions and 95863 deletions
The diffs for the following vendored OpenAPI spec files are suppressed because they are too large:

    docs/static/deprecated-llama-stack-spec.json     (18640 changed lines)
    docs/static/deprecated-llama-stack-spec.yaml     (28360 changed lines)
    docs/static/experimental-llama-stack-spec.json   (5135 changed lines)
    docs/static/experimental-llama-stack-spec.yaml   (3937 changed lines)
    docs/static/llama-stack-spec.json                (11541 changed lines)
    docs/static/llama-stack-spec.yaml                (14957 changed lines)
    docs/static/stainless-llama-stack-spec.json      (16303 changed lines)
    docs/static/stainless-llama-stack-spec.yaml      (20539 changed lines)

One further large file diff is also suppressed; its path does not appear in this capture.
Modified: an OpenAPI schema validation script (path not shown in this capture). The change imports traceback and prints the full stack trace when schema validation raises:

@@ -15,6 +15,7 @@ using multiple validation tools and approaches.
 import argparse
 import json
 import sys
+import traceback
 from pathlib import Path
 from typing import Any

@@ -44,6 +45,8 @@ def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI
             return False
     except Exception as e:
         print(f"❌ {schema_name} validation error: {e}")
+        print(" Traceback:")
+        traceback.print_exc()
         return False
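
The effect of the second hunk is that a failed validation now prints the full stack trace instead of only the exception message. A minimal self-contained sketch of the same pattern (the check function and its trivial validation rule are illustrative stand-ins, not from the commit):

import traceback

def check(schema: dict, name: str = "OpenAPI") -> bool:
    # Stand-in for the script's validate_openapi_schema error path.
    try:
        if "openapi" not in schema:
            raise ValueError("missing 'openapi' version key")
        return True
    except Exception as e:
        print(f"❌ {name} validation error: {e}")
        print(" Traceback:")
        traceback.print_exc()  # full stack trace, as the commit now emits
        return False

check({}, "llama-stack-spec")  # prints the error message plus its traceback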
Modified: the agents package __init__.py (path not shown in this capture). The star re-export of the old monolithic agents module is replaced by explicit imports from the new routes, agents_service, models, and openai_responses modules, plus a backward-compatibility alias:

@@ -4,4 +4,108 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .agents import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .agents_service import AgentsService
+from .models import (
+    Agent,
+    AgentConfig,
+    AgentConfigCommon,
+    AgentConfigOverridablePerTurn,
+    AgentCreateResponse,
+    AgentSessionCreateResponse,
+    AgentStepResponse,
+    AgentToolGroup,
+    AgentToolGroupWithArgs,
+    AgentTurnCreateRequest,
+    AgentTurnResponseEvent,
+    AgentTurnResponseEventPayload,
+    AgentTurnResponseEventType,
+    AgentTurnResponseStepCompletePayload,
+    AgentTurnResponseStepProgressPayload,
+    AgentTurnResponseStepStartPayload,
+    AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnAwaitingInputPayload,
+    AgentTurnResponseTurnCompletePayload,
+    AgentTurnResponseTurnStartPayload,
+    AgentTurnResumeRequest,
+    Attachment,
+    CreateAgentSessionRequest,
+    CreateOpenAIResponseRequest,
+    Document,
+    InferenceStep,
+    MemoryRetrievalStep,
+    ResponseGuardrail,
+    ResponseGuardrailSpec,
+    Session,
+    ShieldCallStep,
+    Step,
+    StepCommon,
+    StepType,
+    ToolExecutionStep,
+    Turn,
+)
+from .openai_responses import (
+    ListOpenAIResponseInputItem,
+    ListOpenAIResponseObject,
+    OpenAIDeleteResponseObject,
+    OpenAIResponseInput,
+    OpenAIResponseInputTool,
+    OpenAIResponseObject,
+    OpenAIResponseObjectStream,
+    OpenAIResponsePrompt,
+    OpenAIResponseText,
+)
+
+# Backward compatibility - export Agents as alias for AgentsService
+Agents = AgentsService
+
+__all__ = [
+    "Agents",
+    "AgentsService",
+    "Agent",
+    "AgentConfig",
+    "AgentConfigCommon",
+    "AgentConfigOverridablePerTurn",
+    "AgentCreateResponse",
+    "AgentSessionCreateResponse",
+    "AgentStepResponse",
+    "AgentToolGroup",
+    "AgentToolGroupWithArgs",
+    "AgentTurnCreateRequest",
+    "AgentTurnResumeRequest",
+    "AgentTurnResponseEvent",
+    "AgentTurnResponseEventPayload",
+    "AgentTurnResponseEventType",
+    "AgentTurnResponseStepCompletePayload",
+    "AgentTurnResponseStepProgressPayload",
+    "AgentTurnResponseStepStartPayload",
+    "AgentTurnResponseStreamChunk",
+    "AgentTurnResponseTurnAwaitingInputPayload",
+    "AgentTurnResponseTurnCompletePayload",
+    "AgentTurnResponseTurnStartPayload",
+    "Attachment",
+    "CreateAgentSessionRequest",
+    "CreateOpenAIResponseRequest",
+    "Document",
+    "InferenceStep",
+    "MemoryRetrievalStep",
+    "ResponseGuardrail",
+    "ResponseGuardrailSpec",
+    "Session",
+    "ShieldCallStep",
+    "Step",
+    "StepCommon",
+    "StepType",
+    "ToolExecutionStep",
+    "Turn",
+    "ListOpenAIResponseInputItem",
+    "ListOpenAIResponseObject",
+    "OpenAIDeleteResponseObject",
+    "OpenAIResponseInput",
+    "OpenAIResponseInputTool",
+    "OpenAIResponseObject",
+    "OpenAIResponseObjectStream",
+    "OpenAIResponsePrompt",
+    "OpenAIResponseText",
+]
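
Because Agents is re-exported as an alias of AgentsService, callers that imported the old name keep working unchanged. A minimal sketch of what the alias guarantees (assuming llama_stack is installed at this commit):

from llama_stack.apis.agents import Agents, AgentsService

# The alias is the same object, not a copy, so isinstance checks,
# subclassing, and type annotations written against Agents remain valid.
assert Agents is AgentsService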
Deleted: the monolithic agents API module previously re-exported as .agents (814 lines; path not shown in this capture). Its content, which this commit splits into models.py and agents_service.py below:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, ConfigDict, Field

from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
    CompletionMessage,
    ResponseFormat,
    SamplingParams,
    ToolCall,
    ToolChoice,
    ToolConfig,
    ToolPromptFormat,
    ToolResponse,
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod

from .openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseInput,
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponsePrompt,
    OpenAIResponseText,
)


@json_schema_type
class ResponseGuardrailSpec(BaseModel):
    """Specification for a guardrail to apply during response generation.

    :param type: The type/identifier of the guardrail.
    """

    type: str
    # TODO: more fields to be added for guardrail configuration


ResponseGuardrail = str | ResponseGuardrailSpec


class Attachment(BaseModel):
    """An attachment to an agent turn.

    :param content: The content of the attachment.
    :param mime_type: The MIME type of the attachment.
    """

    content: InterleavedContent | URL
    mime_type: str


class Document(BaseModel):
    """A document to be used by an agent.

    :param content: The content of the document.
    :param mime_type: The MIME type of the document.
    """

    content: InterleavedContent | URL
    mime_type: str


class StepCommon(BaseModel):
    """A common step in an agent turn.

    :param turn_id: The ID of the turn.
    :param step_id: The ID of the step.
    :param started_at: The time the step started.
    :param completed_at: The time the step completed.
    """

    turn_id: str
    step_id: str
    started_at: datetime | None = None
    completed_at: datetime | None = None


class StepType(StrEnum):
    """Type of the step in an agent turn.

    :cvar inference: The step is an inference step that calls an LLM.
    :cvar tool_execution: The step is a tool execution step that executes a tool call.
    :cvar shield_call: The step is a shield call step that checks for safety violations.
    :cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
    """

    inference = "inference"
    tool_execution = "tool_execution"
    shield_call = "shield_call"
    memory_retrieval = "memory_retrieval"


@json_schema_type
class InferenceStep(StepCommon):
    """An inference step in an agent turn.

    :param model_response: The response from the LLM.
    """

    model_config = ConfigDict(protected_namespaces=())

    step_type: Literal[StepType.inference] = StepType.inference
    model_response: CompletionMessage


@json_schema_type
class ToolExecutionStep(StepCommon):
    """A tool execution step in an agent turn.

    :param tool_calls: The tool calls to execute.
    :param tool_responses: The tool responses from the tool calls.
    """

    step_type: Literal[StepType.tool_execution] = StepType.tool_execution
    tool_calls: list[ToolCall]
    tool_responses: list[ToolResponse]


@json_schema_type
class ShieldCallStep(StepCommon):
    """A shield call step in an agent turn.

    :param violation: The violation from the shield call.
    """

    step_type: Literal[StepType.shield_call] = StepType.shield_call
    violation: SafetyViolation | None


@json_schema_type
class MemoryRetrievalStep(StepCommon):
    """A memory retrieval step in an agent turn.

    :param vector_store_ids: The IDs of the vector databases to retrieve context from.
    :param inserted_context: The context retrieved from the vector databases.
    """

    step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
    # TODO: should this be List[str]?
    vector_store_ids: str
    inserted_context: InterleavedContent


Step = Annotated[
    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
    Field(discriminator="step_type"),
]


@json_schema_type
class Turn(BaseModel):
    """A single turn in an interaction with an Agentic System.

    :param turn_id: Unique identifier for the turn within a session
    :param session_id: Unique identifier for the conversation session
    :param input_messages: List of messages that initiated this turn
    :param steps: Ordered list of processing steps executed during this turn
    :param output_message: The model's generated response containing content and metadata
    :param output_attachments: (Optional) Files or media attached to the agent's response
    :param started_at: Timestamp when the turn began
    :param completed_at: (Optional) Timestamp when the turn finished, if completed
    """

    turn_id: str
    session_id: str
    input_messages: list[UserMessage | ToolResponseMessage]
    steps: list[Step]
    output_message: CompletionMessage
    output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])

    started_at: datetime
    completed_at: datetime | None = None


@json_schema_type
class Session(BaseModel):
    """A single session of an interaction with an Agentic System.

    :param session_id: Unique identifier for the conversation session
    :param session_name: Human-readable name for the session
    :param turns: List of all turns that have occurred in this session
    :param started_at: Timestamp when the session was created
    """

    session_id: str
    session_name: str
    turns: list[Turn]
    started_at: datetime


class AgentToolGroupWithArgs(BaseModel):
    name: str
    args: dict[str, Any]


AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")


class AgentConfigCommon(BaseModel):
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

    input_shields: list[str] | None = Field(default_factory=lambda: [])
    output_shields: list[str] | None = Field(default_factory=lambda: [])
    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
    client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
    tool_config: ToolConfig | None = Field(default=None)

    max_infer_iters: int | None = 10

    def model_post_init(self, __context):
        if self.tool_config:
            if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
                raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
            if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
                raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
        else:
            params = {}
            if self.tool_choice:
                params["tool_choice"] = self.tool_choice
            if self.tool_prompt_format:
                params["tool_prompt_format"] = self.tool_prompt_format
            self.tool_config = ToolConfig(**params)


@json_schema_type
class AgentConfig(AgentConfigCommon):
    """Configuration for an agent.

    :param model: The model identifier to use for the agent
    :param instructions: The system instructions for the agent
    :param name: Optional name for the agent, used in telemetry and identification
    :param enable_session_persistence: Optional flag indicating whether session data has to be persisted
    :param response_format: Optional response format configuration
    """

    model: str
    instructions: str
    name: str | None = None
    enable_session_persistence: bool | None = False
    response_format: ResponseFormat | None = None


@json_schema_type
class Agent(BaseModel):
    """An agent instance with configuration and metadata.

    :param agent_id: Unique identifier for the agent
    :param agent_config: Configuration settings for the agent
    :param created_at: Timestamp when the agent was created
    """

    agent_id: str
    agent_config: AgentConfig
    created_at: datetime


class AgentConfigOverridablePerTurn(AgentConfigCommon):
    instructions: str | None = None


class AgentTurnResponseEventType(StrEnum):
    step_start = "step_start"
    step_complete = "step_complete"
    step_progress = "step_progress"

    turn_start = "turn_start"
    turn_complete = "turn_complete"
    turn_awaiting_input = "turn_awaiting_input"


@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
    """Payload for step start events in agent turn responses.

    :param event_type: Type of event being reported
    :param step_type: Type of step being executed
    :param step_id: Unique identifier for the step within a turn
    :param metadata: (Optional) Additional metadata for the step
    """

    event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
    step_type: StepType
    step_id: str
    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})


@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
    """Payload for step completion events in agent turn responses.

    :param event_type: Type of event being reported
    :param step_type: Type of step being executed
    :param step_id: Unique identifier for the step within a turn
    :param step_details: Complete details of the executed step
    """

    event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
    step_type: StepType
    step_id: str
    step_details: Step


@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
    """Payload for step progress events in agent turn responses.

    :param event_type: Type of event being reported
    :param step_type: Type of step being executed
    :param step_id: Unique identifier for the step within a turn
    :param delta: Incremental content changes during step execution
    """

    model_config = ConfigDict(protected_namespaces=())

    event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
    step_type: StepType
    step_id: str

    delta: ContentDelta


@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
    """Payload for turn start events in agent turn responses.

    :param event_type: Type of event being reported
    :param turn_id: Unique identifier for the turn within a session
    """

    event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
    turn_id: str


@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
    """Payload for turn completion events in agent turn responses.

    :param event_type: Type of event being reported
    :param turn: Complete turn data including all steps and results
    """

    event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
    turn: Turn


@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
    """Payload for turn awaiting input events in agent turn responses.

    :param event_type: Type of event being reported
    :param turn: Turn data when waiting for external tool responses
    """

    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
    turn: Turn


AgentTurnResponseEventPayload = Annotated[
    AgentTurnResponseStepStartPayload
    | AgentTurnResponseStepProgressPayload
    | AgentTurnResponseStepCompletePayload
    | AgentTurnResponseTurnStartPayload
    | AgentTurnResponseTurnCompletePayload
    | AgentTurnResponseTurnAwaitingInputPayload,
    Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")


@json_schema_type
class AgentTurnResponseEvent(BaseModel):
    """An event in an agent turn response stream.

    :param payload: Event-specific payload containing event data
    """

    payload: AgentTurnResponseEventPayload


@json_schema_type
class AgentCreateResponse(BaseModel):
    """Response returned when creating a new agent.

    :param agent_id: Unique identifier for the created agent
    """

    agent_id: str


@json_schema_type
class AgentSessionCreateResponse(BaseModel):
    """Response returned when creating a new agent session.

    :param session_id: Unique identifier for the created session
    """

    session_id: str


@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
    """Request to create a new turn for an agent.

    :param agent_id: Unique identifier for the agent
    :param session_id: Unique identifier for the conversation session
    :param messages: List of messages to start the turn with
    :param documents: (Optional) List of documents to provide to the agent
    :param toolgroups: (Optional) List of tool groups to make available for this turn
    :param stream: (Optional) Whether to stream the response
    :param tool_config: (Optional) Tool configuration to override agent defaults
    """

    agent_id: str
    session_id: str

    # TODO: figure out how we can simplify this and make why
    # ToolResponseMessage needs to be here (it is function call
    # execution from outside the system)
    messages: list[UserMessage | ToolResponseMessage]

    documents: list[Document] | None = None
    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])

    stream: bool | None = False
    tool_config: ToolConfig | None = None


@json_schema_type
class AgentTurnResumeRequest(BaseModel):
    """Request to resume an agent turn with tool responses.

    :param agent_id: Unique identifier for the agent
    :param session_id: Unique identifier for the conversation session
    :param turn_id: Unique identifier for the turn within a session
    :param tool_responses: List of tool responses to submit to continue the turn
    :param stream: (Optional) Whether to stream the response
    """

    agent_id: str
    session_id: str
    turn_id: str
    tool_responses: list[ToolResponse]
    stream: bool | None = False


@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
    """Streamed agent turn completion response.

    :param event: Individual event in the agent turn response stream
    """

    event: AgentTurnResponseEvent


@json_schema_type
class AgentStepResponse(BaseModel):
    """Response containing details of a specific agent step.

    :param step: The complete step data and execution details
    """

    step: Step


@runtime_checkable
class Agents(Protocol):
    """Agents

    APIs for creating and interacting with agentic systems."""

    @webmethod(
        route="/agents",
        method="POST",
        descriptive_name="create_agent",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse:
        """Create an agent with the given configuration.

        :param agent_config: The configuration for the agent.
        :returns: An AgentCreateResponse with the agent ID.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn",
        method="POST",
        descriptive_name="create_agent_turn",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: list[UserMessage | ToolResponseMessage],
        stream: bool | None = False,
        documents: list[Document] | None = None,
        toolgroups: list[AgentToolGroup] | None = None,
        tool_config: ToolConfig | None = None,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Create a new turn for an agent.

        :param agent_id: The ID of the agent to create the turn for.
        :param session_id: The ID of the session to create the turn for.
        :param messages: List of messages to start the turn with.
        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
        :param documents: (Optional) List of documents to create the turn with.
        :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
        :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
        :returns: If stream=False, returns a Turn object.
            If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
        method="POST",
        descriptive_name="resume_agent_turn",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def resume_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
        tool_responses: list[ToolResponse],
        stream: bool | None = False,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Resume an agent turn with executed tool call responses.

        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.

        :param agent_id: The ID of the agent to resume.
        :param session_id: The ID of the session to resume.
        :param turn_id: The ID of the turn to resume.
        :param tool_responses: The tool call responses to resume the turn with.
        :param stream: Whether to stream the response.
        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
        method="GET",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def get_agents_turn(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
    ) -> Turn:
        """Retrieve an agent turn by its ID.

        :param agent_id: The ID of the agent to get the turn for.
        :param session_id: The ID of the session to get the turn for.
        :param turn_id: The ID of the turn to get.
        :returns: A Turn.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
        method="GET",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def get_agents_step(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
        step_id: str,
    ) -> AgentStepResponse:
        """Retrieve an agent step by its ID.

        :param agent_id: The ID of the agent to get the step for.
        :param session_id: The ID of the session to get the step for.
        :param turn_id: The ID of the turn to get the step for.
        :param step_id: The ID of the step to get.
        :returns: An AgentStepResponse.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session",
        method="POST",
        descriptive_name="create_agent_session",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        """Create a new session for an agent.

        :param agent_id: The ID of the agent to create the session for.
        :param session_name: The name of the session to create.
        :returns: An AgentSessionCreateResponse.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}",
        method="GET",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def get_agents_session(
        self,
        session_id: str,
        agent_id: str,
        turn_ids: list[str] | None = None,
    ) -> Session:
        """Retrieve an agent session by its ID.

        :param session_id: The ID of the session to get.
        :param agent_id: The ID of the agent to get the session for.
        :param turn_ids: (Optional) List of turn IDs to filter the session by.
        :returns: A Session.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}",
        method="DELETE",
        level=LLAMA_STACK_API_V1ALPHA,
    )
    async def delete_agents_session(
        self,
        session_id: str,
        agent_id: str,
    ) -> None:
        """Delete an agent session by its ID and its associated turns.

        :param session_id: The ID of the session to delete.
        :param agent_id: The ID of the agent to delete the session for.
        """
        ...

    @webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
    async def delete_agent(
        self,
        agent_id: str,
    ) -> None:
        """Delete an agent by its ID and its associated sessions and turns.

        :param agent_id: The ID of the agent to delete.
        """
        ...

    @webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
        """List all agents.

        :param start_index: The index to start the pagination from.
        :param limit: The number of agents to return.
        :returns: A PaginatedResponse.
        """
        ...

    @webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_agent(self, agent_id: str) -> Agent:
        """Describe an agent by its ID.

        :param agent_id: ID of the agent.
        :returns: An Agent of the agent.
        """
        ...

    @webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def list_agent_sessions(
        self,
        agent_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        """List all session(s) of a given agent.

        :param agent_id: The ID of the agent to list sessions for.
        :param start_index: The index to start the pagination from.
        :param limit: The number of sessions to return.
        :returns: A PaginatedResponse.
        """
        ...

    # We situate the OpenAI Responses API in the Agents API just like we did things
    # for Inference. The Responses API, in its intent, serves the same purpose as
    # the Agents API above -- it is essentially a lightweight "agentic loop" with
    # integrated tool calling.
    #
    # Both of these APIs are inherently stateful.

    @webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_openai_response(
        self,
        response_id: str,
    ) -> OpenAIResponseObject:
        """Get a model response.

        :param response_id: The ID of the OpenAI response to retrieve.
        :returns: An OpenAIResponseObject.
        """
        ...

    @webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePrompt | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        conversation: str | None = None,
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
        text: OpenAIResponseText | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
        guardrails: Annotated[
            list[ResponseGuardrail] | None,
            ExtraBodyField(
                "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
            ),
        ] = None,
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a model response.

        :param input: Input message(s) to create the response.
        :param model: The underlying LLM used for completions.
        :param prompt: (Optional) Prompt object with ID, version, and variables.
        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
        :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
        :param include: (Optional) Additional fields to include in the response.
        :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
        :returns: An OpenAIResponseObject.
        """
        ...

    @webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
    async def list_openai_responses(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseObject:
        """List all responses.

        :param after: The ID of the last response to return.
        :param limit: The number of responses to return.
        :param model: The model to filter responses by.
        :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
        :returns: A ListOpenAIResponseObject.
        """
        ...

    @webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
    async def list_openai_response_input_items(
        self,
        response_id: str,
        after: str | None = None,
        before: str | None = None,
        include: list[str] | None = None,
        limit: int | None = 20,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseInputItem:
        """List input items.

        :param response_id: The ID of the response to retrieve input items for.
        :param after: An item ID to list items after, used for pagination.
        :param before: An item ID to list items before, used for pagination.
        :param include: Additional fields to include in the response.
        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
        :param order: The order to return the input items in. Default is desc.
        :returns: An ListOpenAIResponseInputItem.
        """
        ...

    @webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
        """Delete a response.

        :param response_id: The ID of the OpenAI response to delete.
        :returns: An OpenAIDeleteResponseObject
        """
        ...
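
The Protocol above encodes the turn lifecycle its docstrings describe: create an agent, open a session, start a turn, and, when a streamed turn reports turn_awaiting_input, resume it with the client-side tool outputs. A hedged sketch of a caller driving any implementation of this Protocol (impl, the model name, and the message text are hypothetical; error handling omitted):

from llama_stack.apis.agents import AgentConfig, AgentTurnResponseTurnAwaitingInputPayload
from llama_stack.apis.inference import UserMessage

async def run_one_turn(impl) -> None:
    # impl: any object satisfying the Agents Protocol.
    agent = await impl.create_agent(AgentConfig(model="my-model", instructions="Be terse."))
    session = await impl.create_agent_session(agent.agent_id, session_name="demo")
    stream = await impl.create_agent_turn(
        agent_id=agent.agent_id,
        session_id=session.session_id,
        messages=[UserMessage(content="What's the weather in Paris?")],
        stream=True,
    )
    async for chunk in stream:
        payload = chunk.event.payload
        if isinstance(payload, AgentTurnResponseTurnAwaitingInputPayload):
            # Client-side tool calls are pending: execute them, then continue with
            # impl.resume_agent_turn(..., tool_responses=[...], stream=True).
            break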
Added: src/llama_stack/apis/agents/agents_service.py (new file, 308 lines)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Annotated, Protocol, runtime_checkable

from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import ToolConfig, ToolResponse, ToolResponseMessage, UserMessage
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import ExtraBodyField

from .models import (
    Agent,
    AgentConfig,
    AgentCreateResponse,
    AgentSessionCreateResponse,
    AgentStepResponse,
    AgentToolGroup,
    AgentTurnResponseStreamChunk,
    Document,
    ResponseGuardrail,
    Session,
    Turn,
)
from .openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseInput,
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponsePrompt,
    OpenAIResponseText,
)


@runtime_checkable
@trace_protocol
class AgentsService(Protocol):
    """Agents

    APIs for creating and interacting with agentic systems."""

    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse:
        """Create an agent with the given configuration.

        :param agent_config: The configuration for the agent.
        :returns: An AgentCreateResponse with the agent ID.
        """
        ...

    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: list[UserMessage | ToolResponseMessage],
        stream: bool | None = False,
        documents: list[Document] | None = None,
        toolgroups: list[AgentToolGroup] | None = None,
        tool_config: ToolConfig | None = None,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Create a new turn for an agent.

        :param agent_id: The ID of the agent to create the turn for.
        :param session_id: The ID of the session to create the turn for.
        :param messages: List of messages to start the turn with.
        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
        :param documents: (Optional) List of documents to create the turn with.
        :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
        :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
        :returns: If stream=False, returns a Turn object.
            If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
        """
        ...

    async def resume_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
        tool_responses: list[ToolResponse],
        stream: bool | None = False,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Resume an agent turn with executed tool call responses.

        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.

        :param agent_id: The ID of the agent to resume.
        :param session_id: The ID of the session to resume.
        :param turn_id: The ID of the turn to resume.
        :param tool_responses: The tool call responses to resume the turn with.
        :param stream: Whether to stream the response.
        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
        """
        ...

    async def get_agents_turn(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
    ) -> Turn:
        """Retrieve an agent turn by its ID.

        :param agent_id: The ID of the agent to get the turn for.
        :param session_id: The ID of the session to get the turn for.
        :param turn_id: The ID of the turn to get.
        :returns: A Turn.
        """
        ...

    async def get_agents_step(
        self,
        agent_id: str,
        session_id: str,
        turn_id: str,
        step_id: str,
    ) -> AgentStepResponse:
        """Retrieve an agent step by its ID.

        :param agent_id: The ID of the agent to get the step for.
        :param session_id: The ID of the session to get the step for.
        :param turn_id: The ID of the turn to get the step for.
        :param step_id: The ID of the step to get.
        :returns: An AgentStepResponse.
        """
        ...

    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        """Create a new session for an agent.

        :param agent_id: The ID of the agent to create the session for.
        :param session_name: The name of the session to create.
        :returns: An AgentSessionCreateResponse.
        """
        ...

    async def get_agents_session(
        self,
        session_id: str,
        agent_id: str,
        turn_ids: list[str] | None = None,
    ) -> Session:
        """Retrieve an agent session by its ID.

        :param session_id: The ID of the session to get.
        :param agent_id: The ID of the agent to get the session for.
        :param turn_ids: (Optional) List of turn IDs to filter the session by.
        :returns: A Session.
        """
        ...

    async def delete_agents_session(
        self,
        session_id: str,
        agent_id: str,
    ) -> None:
        """Delete an agent session by its ID and its associated turns.

        :param session_id: The ID of the session to delete.
        :param agent_id: The ID of the agent to delete the session for.
        """
        ...

    async def delete_agent(
        self,
        agent_id: str,
    ) -> None:
        """Delete an agent by its ID and its associated sessions and turns.

        :param agent_id: The ID of the agent to delete.
        """
        ...

    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
        """List all agents.

        :param start_index: The index to start the pagination from.
        :param limit: The number of agents to return.
        :returns: A PaginatedResponse.
        """
        ...

    async def get_agent(self, agent_id: str) -> Agent:
        """Describe an agent by its ID.

        :param agent_id: ID of the agent.
        :returns: An Agent of the agent.
        """
        ...

    async def list_agent_sessions(
        self,
        agent_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        """List all session(s) of a given agent.

        :param agent_id: The ID of the agent to list sessions for.
        :param start_index: The index to start the pagination from.
        :param limit: The number of sessions to return.
        :returns: A PaginatedResponse.
        """
        ...

    async def get_openai_response(
        self,
        response_id: str,
    ) -> OpenAIResponseObject:
        """Get a model response.

        :param response_id: The ID of the OpenAI response to retrieve.
        :returns: An OpenAIResponseObject.
        """
        ...

    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePrompt | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        conversation: str | None = None,
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
        text: OpenAIResponseText | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,
        guardrails: Annotated[
            list[ResponseGuardrail] | None,
            ExtraBodyField(
                "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
            ),
        ] = None,
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a model response.

        :param input: Input message(s) to create the response.
        :param model: The underlying LLM used for completions.
        :param prompt: (Optional) Prompt object with ID, version, and variables.
        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
        :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
        :param include: (Optional) Additional fields to include in the response.
        :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
        :returns: An OpenAIResponseObject.
        """
        ...

    async def list_openai_responses(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseObject:
        """List all responses.

        :param after: The ID of the last response to return.
        :param limit: The number of responses to return.
        :param model: The model to filter responses by.
        :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
        :returns: A ListOpenAIResponseObject.
        """
        ...

    async def list_openai_response_input_items(
        self,
        response_id: str,
        after: str | None = None,
        before: str | None = None,
        include: list[str] | None = None,
        limit: int | None = 20,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseInputItem:
        """List input items.

        :param response_id: The ID of the response to retrieve input items for.
        :param after: An item ID to list items after, used for pagination.
        :param before: An item ID to list items before, used for pagination.
        :param include: Additional fields to include in the response.
        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
        :param order: The order to return the input items in. Default is desc.
        :returns: An ListOpenAIResponseInputItem.
        """
        ...

    async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
        """Delete a response.

        :param response_id: The ID of the OpenAI response to delete.
        :returns: An OpenAIDeleteResponseObject
        """
        ...
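
AgentsService is declared @runtime_checkable, so providers can be validated structurally at runtime instead of by inheritance; @trace_protocol (from llama_stack.core.telemetry) presumably instruments implementations for tracing. A minimal sketch of the runtime check (the stub class is illustrative only, and isinstance against a runtime-checkable Protocol verifies that the named methods exist, not their signatures):

from llama_stack.apis.agents import AgentsService

class StubAgents:
    # Only a subset of the Protocol's methods is defined here, so the
    # structural check below fails until every method is present.
    async def create_agent(self, agent_config): ...
    async def create_agent_turn(self, *args, **kwargs): ...

print(isinstance(StubAgents(), AgentsService))  # False: methods are missing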
Added: src/llama_stack/apis/agents/models.py (new file, 409 lines; the capture below ends partway through the file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal

from pydantic import BaseModel, ConfigDict, Field

from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.inference import (
    CompletionMessage,
    ResponseFormat,
    SamplingParams,
    ToolCall,
    ToolChoice,
    ToolConfig,
    ToolPromptFormat,
    ToolResponse,
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema

from .openai_responses import (
    OpenAIResponseInput,
    OpenAIResponseInputTool,
    OpenAIResponsePrompt,
    OpenAIResponseText,
)


@json_schema_type
class ResponseGuardrailSpec(BaseModel):
    """Specification for a guardrail to apply during response generation."""

    type: str = Field(description="The type/identifier of the guardrail.")
    # TODO: more fields to be added for guardrail configuration


ResponseGuardrail = str | ResponseGuardrailSpec


class Attachment(BaseModel):
    """An attachment to an agent turn."""

    content: InterleavedContent | URL = Field(description="The content of the attachment.")
    mime_type: str = Field(description="The MIME type of the attachment.")


class Document(BaseModel):
    """A document to be used by an agent."""

    content: InterleavedContent | URL = Field(description="The content of the document.")
    mime_type: str = Field(description="The MIME type of the document.")


class StepCommon(BaseModel):
    """A common step in an agent turn."""

    turn_id: str = Field(description="The ID of the turn.")
    step_id: str = Field(description="The ID of the step.")
    started_at: datetime | None = Field(default=None, description="The time the step started.")
    completed_at: datetime | None = Field(default=None, description="The time the step completed.")


class StepType(StrEnum):
    """Type of the step in an agent turn."""

    inference = "inference"
    tool_execution = "tool_execution"
    shield_call = "shield_call"
    memory_retrieval = "memory_retrieval"


@json_schema_type
class InferenceStep(StepCommon):
    """An inference step in an agent turn."""

    model_config = ConfigDict(protected_namespaces=())

    step_type: Literal[StepType.inference] = Field(default=StepType.inference)
    model_response: CompletionMessage = Field(description="The response from the LLM.")


@json_schema_type
class ToolExecutionStep(StepCommon):
    """A tool execution step in an agent turn."""

    step_type: Literal[StepType.tool_execution] = Field(default=StepType.tool_execution)
    tool_calls: list[ToolCall] = Field(description="The tool calls to execute.")
    tool_responses: list[ToolResponse] = Field(description="The tool responses from the tool calls.")


@json_schema_type
class ShieldCallStep(StepCommon):
    """A shield call step in an agent turn."""

    step_type: Literal[StepType.shield_call] = Field(default=StepType.shield_call)
    violation: SafetyViolation | None = Field(default=None, description="The violation from the shield call.")


@json_schema_type
class MemoryRetrievalStep(StepCommon):
    """A memory retrieval step in an agent turn."""

    step_type: Literal[StepType.memory_retrieval] = Field(default=StepType.memory_retrieval)
    # TODO: should this be List[str]?
    vector_store_ids: str = Field(description="The IDs of the vector databases to retrieve context from.")
    inserted_context: InterleavedContent = Field(description="The context retrieved from the vector databases.")


Step = Annotated[
    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
    Field(discriminator="step_type"),
]


@json_schema_type
class Turn(BaseModel):
    """A single turn in an interaction with an Agentic System."""

    turn_id: str = Field(description="Unique identifier for the turn within a session")
    session_id: str = Field(description="Unique identifier for the conversation session")
    input_messages: list[UserMessage | ToolResponseMessage] = Field(
        description="List of messages that initiated this turn"
    )
    steps: list[Step] = Field(description="Ordered list of processing steps executed during this turn")
    output_message: CompletionMessage = Field(
        description="The model's generated response containing content and metadata"
    )
    output_attachments: list[Attachment] | None = Field(
        default_factory=lambda: [], description="Files or media attached to the agent's response"
    )

    started_at: datetime = Field(description="Timestamp when the turn began")
    completed_at: datetime | None = Field(default=None, description="Timestamp when the turn finished, if completed")


@json_schema_type
class Session(BaseModel):
    """A single session of an interaction with an Agentic System."""

    session_id: str = Field(description="Unique identifier for the conversation session")
    session_name: str = Field(description="Human-readable name for the session")
    turns: list[Turn] = Field(description="List of all turns that have occurred in this session")
    started_at: datetime = Field(description="Timestamp when the session was created")


class AgentToolGroupWithArgs(BaseModel):
    name: str = Field()
    args: dict[str, Any] = Field()


AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")


class AgentConfigCommon(BaseModel):
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

    input_shields: list[str] | None = Field(default_factory=lambda: [])
    output_shields: list[str] | None = Field(default_factory=lambda: [])
    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||||
|
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||||
|
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||||
|
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||||
|
tool_config: ToolConfig | None = Field(default=None)
|
||||||
|
|
||||||
|
max_infer_iters: int | None = 10
|
||||||
|
|
||||||
|
def model_post_init(self, __context):
|
||||||
|
if self.tool_config:
|
||||||
|
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
|
||||||
|
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||||
|
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||||
|
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||||
|
else:
|
||||||
|
params = {}
|
||||||
|
if self.tool_choice:
|
||||||
|
params["tool_choice"] = self.tool_choice
|
||||||
|
if self.tool_prompt_format:
|
||||||
|
params["tool_prompt_format"] = self.tool_prompt_format
|
||||||
|
self.tool_config = ToolConfig(**params)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentConfig(AgentConfigCommon):
|
||||||
|
"""Configuration for an agent."""
|
||||||
|
|
||||||
|
model: str = Field(description="The model identifier to use for the agent")
|
||||||
|
instructions: str = Field(description="The system instructions for the agent")
|
||||||
|
name: str | None = Field(
|
||||||
|
default=None, description="Optional name for the agent, used in telemetry and identification"
|
||||||
|
)
|
||||||
|
enable_session_persistence: bool | None = Field(
|
||||||
|
default=False, description="Optional flag indicating whether session data has to be persisted"
|
||||||
|
)
|
||||||
|
response_format: ResponseFormat | None = Field(default=None, description="Optional response format configuration")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Agent(BaseModel):
|
||||||
|
"""An agent instance with configuration and metadata."""
|
||||||
|
|
||||||
|
agent_id: str = Field(description="Unique identifier for the agent")
|
||||||
|
agent_config: AgentConfig = Field(description="Configuration settings for the agent")
|
||||||
|
created_at: datetime = Field(description="Timestamp when the agent was created")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||||
|
instructions: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTurnResponseEventType(StrEnum):
|
||||||
|
step_start = "step_start"
|
||||||
|
step_complete = "step_complete"
|
||||||
|
step_progress = "step_progress"
|
||||||
|
|
||||||
|
turn_start = "turn_start"
|
||||||
|
turn_complete = "turn_complete"
|
||||||
|
turn_awaiting_input = "turn_awaiting_input"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||||
|
"""Payload for step start events in agent turn responses."""
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.step_start] = Field(
|
||||||
|
default=AgentTurnResponseEventType.step_start, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
step_type: StepType = Field(description="Type of step being executed")
|
||||||
|
step_id: str = Field(description="Unique identifier for the step within a turn")
|
||||||
|
metadata: dict[str, Any] | None = Field(default_factory=lambda: {}, description="Additional metadata for the step")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||||
|
"""Payload for step completion events in agent turn responses."""
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.step_complete] = Field(
|
||||||
|
default=AgentTurnResponseEventType.step_complete, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
step_type: StepType = Field(description="Type of step being executed")
|
||||||
|
step_id: str = Field(description="Unique identifier for the step within a turn")
|
||||||
|
step_details: Step = Field(description="Complete details of the executed step")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||||
|
"""Payload for step progress events in agent turn responses."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.step_progress] = Field(
|
||||||
|
default=AgentTurnResponseEventType.step_progress, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
step_type: StepType = Field(description="Type of step being executed")
|
||||||
|
step_id: str = Field(description="Unique identifier for the step within a turn")
|
||||||
|
|
||||||
|
delta: ContentDelta = Field(description="Incremental content changes during step execution")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||||
|
"""Payload for turn start events in agent turn responses."""
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.turn_start] = Field(
|
||||||
|
default=AgentTurnResponseEventType.turn_start, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
turn_id: str = Field(description="Unique identifier for the turn within a session")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||||
|
"""Payload for turn completion events in agent turn responses."""
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.turn_complete] = Field(
|
||||||
|
default=AgentTurnResponseEventType.turn_complete, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
turn: Turn = Field(description="Complete turn data including all steps and results")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||||
|
"""Payload for turn awaiting input events in agent turn responses."""
|
||||||
|
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = Field(
|
||||||
|
default=AgentTurnResponseEventType.turn_awaiting_input, description="Type of event being reported"
|
||||||
|
)
|
||||||
|
turn: Turn = Field(description="Turn data when waiting for external tool responses")
|
||||||
|
|
||||||
|
|
||||||
|
AgentTurnResponseEventPayload = Annotated[
|
||||||
|
AgentTurnResponseStepStartPayload
|
||||||
|
| AgentTurnResponseStepProgressPayload
|
||||||
|
| AgentTurnResponseStepCompletePayload
|
||||||
|
| AgentTurnResponseTurnStartPayload
|
||||||
|
| AgentTurnResponseTurnCompletePayload
|
||||||
|
| AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
|
Field(discriminator="event_type"),
|
||||||
|
]
|
||||||
|
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseEvent(BaseModel):
|
||||||
|
"""An event in an agent turn response stream."""
|
||||||
|
|
||||||
|
payload: AgentTurnResponseEventPayload = Field(description="Event-specific payload containing event data")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentCreateResponse(BaseModel):
|
||||||
|
"""Response returned when creating a new agent."""
|
||||||
|
|
||||||
|
agent_id: str = Field(description="Unique identifier for the created agent")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentSessionCreateResponse(BaseModel):
|
||||||
|
"""Response returned when creating a new agent session."""
|
||||||
|
|
||||||
|
session_id: str = Field(description="Unique identifier for the created session")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
|
"""Request to create a new turn for an agent."""
|
||||||
|
|
||||||
|
agent_id: str = Field(description="Unique identifier for the agent")
|
||||||
|
session_id: str = Field(description="Unique identifier for the conversation session")
|
||||||
|
|
||||||
|
# TODO: figure out how we can simplify this and make why
|
||||||
|
# ToolResponseMessage needs to be here (it is function call
|
||||||
|
# execution from outside the system)
|
||||||
|
messages: list[UserMessage | ToolResponseMessage] = Field(description="List of messages to start the turn with")
|
||||||
|
|
||||||
|
documents: list[Document] | None = Field(default=None, description="List of documents to provide to the agent")
|
||||||
|
toolgroups: list[AgentToolGroup] | None = Field(
|
||||||
|
default_factory=lambda: [], description="List of tool groups to make available for this turn"
|
||||||
|
)
|
||||||
|
|
||||||
|
stream: bool | None = Field(default=False, description="Whether to stream the response")
|
||||||
|
tool_config: ToolConfig | None = Field(default=None, description="Tool configuration to override agent defaults")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResumeRequest(BaseModel):
|
||||||
|
"""Request to resume an agent turn with tool responses."""
|
||||||
|
|
||||||
|
agent_id: str = Field(description="Unique identifier for the agent")
|
||||||
|
session_id: str = Field(description="Unique identifier for the conversation session")
|
||||||
|
turn_id: str = Field(description="Unique identifier for the turn within a session")
|
||||||
|
tool_responses: list[ToolResponse] = Field(description="List of tool responses to submit to continue the turn")
|
||||||
|
stream: bool | None = Field(default=False, description="Whether to stream the response")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseStreamChunk(BaseModel):
|
||||||
|
"""Streamed agent turn completion response."""
|
||||||
|
|
||||||
|
event: AgentTurnResponseEvent = Field(description="Individual event in the agent turn response stream")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentStepResponse(BaseModel):
|
||||||
|
"""Response containing details of a specific agent step."""
|
||||||
|
|
||||||
|
step: Step = Field(description="The complete step data and execution details")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CreateAgentSessionRequest(BaseModel):
|
||||||
|
"""Request to create a new session for an agent."""
|
||||||
|
|
||||||
|
agent_id: str = Field(..., description="The ID of the agent to create the session for")
|
||||||
|
session_name: str = Field(..., description="The name of the session to create")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CreateOpenAIResponseRequest(BaseModel):
|
||||||
|
"""Request to create a model response."""
|
||||||
|
|
||||||
|
input: str | list[OpenAIResponseInput] = Field(..., description="Input message(s) to create the response")
|
||||||
|
model: str = Field(..., description="The underlying LLM used for completions")
|
||||||
|
prompt: OpenAIResponsePrompt | None = Field(None, description="Prompt object with ID, version, and variables")
|
||||||
|
instructions: str | None = Field(None, description="System instructions")
|
||||||
|
previous_response_id: str | None = Field(
|
||||||
|
None, description="If specified, the new response will be a continuation of the previous response"
|
||||||
|
)
|
||||||
|
conversation: str | None = Field(
|
||||||
|
None, description="The ID of a conversation to add the response to. Must begin with 'conv_'"
|
||||||
|
)
|
||||||
|
store: bool = Field(True, description="Whether to store the response")
|
||||||
|
stream: bool = Field(False, description="Whether to stream the response")
|
||||||
|
temperature: float | None = Field(None, description="Sampling temperature")
|
||||||
|
text: OpenAIResponseText | None = Field(None, description="Text generation parameters")
|
||||||
|
tools: list[OpenAIResponseInputTool] | None = Field(None, description="Tools to make available")
|
||||||
|
include: list[str] | None = Field(None, description="Additional fields to include in the response")
|
||||||
|
max_infer_iters: int = Field(10, description="Maximum number of inference iterations (extension to the OpenAI API)")
|
||||||
|
guardrails: list[ResponseGuardrail] | None = Field(
|
||||||
|
None, description="List of guardrails to apply during response generation"
|
||||||
|
)
|
||||||
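
Worth noting for reviewers: the `model_post_init` hook above silently migrates the deprecated top-level `tool_choice`/`tool_prompt_format` fields into `tool_config` when no explicit `tool_config` is supplied, and raises if both are set and disagree. A minimal sketch of the behavior (illustrative only, not part of the diff; the model name is a placeholder):

from llama_stack.apis.inference import ToolChoice

config = AgentConfig(
    model="example-model",            # placeholder identifier, not a real model ID
    instructions="You are a helpful assistant.",
    tool_choice=ToolChoice.auto,      # deprecated field, still accepted
)
# model_post_init folded the legacy field into tool_config:
assert config.tool_config is not None
assert config.tool_config.tool_choice == ToolChoice.auto
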
File diff suppressed because it is too large

452 src/llama_stack/apis/agents/routes.py Normal file

@@ -0,0 +1,452 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.common.responses import Order
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .agents_service import AgentsService
from .models import (
    Agent,
    AgentConfig,
    AgentCreateResponse,
    AgentSessionCreateResponse,
    AgentStepResponse,
    AgentTurnCreateRequest,
    AgentTurnResumeRequest,
    CreateAgentSessionRequest,
    CreateOpenAIResponseRequest,
    Session,
    Turn,
)
from .openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseObject,
)


def get_agents_service(request: Request) -> AgentsService:
    """Dependency to get the agents service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.agents not in impls:
        raise ValueError("Agents API implementation not found")
    return impls[Api.agents]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Agents"],
    responses=standard_responses,
)

router_v1alpha = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
    tags=["Agents"],
    responses=standard_responses,
)


@router.post(
    "/agents",
    response_model=AgentCreateResponse,
    summary="Create an agent.",
    description="Create an agent with the given configuration.",
    deprecated=True,
)
@router_v1alpha.post(
    "/agents",
    response_model=AgentCreateResponse,
    summary="Create an agent.",
    description="Create an agent with the given configuration.",
)
async def create_agent(
    agent_config: AgentConfig = Body(...),
    svc: AgentsService = Depends(get_agents_service),
) -> AgentCreateResponse:
    """Create an agent with the given configuration."""
    return await svc.create_agent(agent_config=agent_config)


@router.post(
    "/agents/{agent_id}/session/{session_id}/turn",
    summary="Create a new turn for an agent.",
    description="Create a new turn for an agent.",
    deprecated=True,
)
@router_v1alpha.post(
    "/agents/{agent_id}/session/{session_id}/turn",
    summary="Create a new turn for an agent.",
    description="Create a new turn for an agent.",
)
async def create_agent_turn(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to create the turn for.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to create the turn for.")],
    body: AgentTurnCreateRequest = Body(...),
    svc: AgentsService = Depends(get_agents_service),
):
    """Create a new turn for an agent."""
    return await svc.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=body.messages,
        stream=body.stream,
        documents=body.documents,
        toolgroups=body.toolgroups,
        tool_config=body.tool_config,
    )


@router.post(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
    summary="Resume an agent turn.",
    description="Resume an agent turn with executed tool call responses.",
    deprecated=True,
)
@router_v1alpha.post(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
    summary="Resume an agent turn.",
    description="Resume an agent turn with executed tool call responses.",
)
async def resume_agent_turn(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to resume.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to resume.")],
    turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to resume.")],
    body: AgentTurnResumeRequest = Body(...),
    svc: AgentsService = Depends(get_agents_service),
):
    """Resume an agent turn with executed tool call responses."""
    return await svc.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn_id,
        tool_responses=body.tool_responses,
        stream=body.stream,
    )


@router.get(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
    response_model=Turn,
    summary="Retrieve an agent turn.",
    description="Retrieve an agent turn by its ID.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
    response_model=Turn,
    summary="Retrieve an agent turn.",
    description="Retrieve an agent turn by its ID.",
)
async def get_agents_turn(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the turn for.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get the turn for.")],
    turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to get.")],
    svc: AgentsService = Depends(get_agents_service),
) -> Turn:
    """Retrieve an agent turn by its ID."""
    return await svc.get_agents_turn(agent_id=agent_id, session_id=session_id, turn_id=turn_id)


@router.get(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
    response_model=AgentStepResponse,
    summary="Retrieve an agent step.",
    description="Retrieve an agent step by its ID.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
    response_model=AgentStepResponse,
    summary="Retrieve an agent step.",
    description="Retrieve an agent step by its ID.",
)
async def get_agents_step(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the step for.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get the step for.")],
    turn_id: Annotated[str, FastAPIPath(..., description="The ID of the turn to get the step for.")],
    step_id: Annotated[str, FastAPIPath(..., description="The ID of the step to get.")],
    svc: AgentsService = Depends(get_agents_service),
) -> AgentStepResponse:
    """Retrieve an agent step by its ID."""
    return await svc.get_agents_step(agent_id=agent_id, session_id=session_id, turn_id=turn_id, step_id=step_id)


@router.post(
    "/agents/{agent_id}/session",
    response_model=AgentSessionCreateResponse,
    summary="Create a new session for an agent.",
    description="Create a new session for an agent.",
    deprecated=True,
)
@router_v1alpha.post(
    "/agents/{agent_id}/session",
    response_model=AgentSessionCreateResponse,
    summary="Create a new session for an agent.",
    description="Create a new session for an agent.",
)
async def create_agent_session(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to create the session for.")],
    body: CreateAgentSessionRequest = Body(...),
    svc: AgentsService = Depends(get_agents_service),
) -> AgentSessionCreateResponse:
    """Create a new session for an agent."""
    return await svc.create_agent_session(agent_id=agent_id, session_name=body.session_name)


@router.get(
    "/agents/{agent_id}/session/{session_id}",
    response_model=Session,
    summary="Retrieve an agent session.",
    description="Retrieve an agent session by its ID.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents/{agent_id}/session/{session_id}",
    response_model=Session,
    summary="Retrieve an agent session.",
    description="Retrieve an agent session by its ID.",
)
async def get_agents_session(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to get the session for.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to get.")],
    turn_ids: list[str] | None = Query(None, description="List of turn IDs to filter the session by."),
    svc: AgentsService = Depends(get_agents_service),
) -> Session:
    """Retrieve an agent session by its ID."""
    return await svc.get_agents_session(session_id=session_id, agent_id=agent_id, turn_ids=turn_ids)


@router.delete(
    "/agents/{agent_id}/session/{session_id}",
    response_model=None,
    status_code=204,
    summary="Delete an agent session.",
    description="Delete an agent session by its ID.",
    deprecated=True,
)
@router_v1alpha.delete(
    "/agents/{agent_id}/session/{session_id}",
    response_model=None,
    status_code=204,
    summary="Delete an agent session.",
    description="Delete an agent session by its ID.",
)
async def delete_agents_session(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to delete the session for.")],
    session_id: Annotated[str, FastAPIPath(..., description="The ID of the session to delete.")],
    svc: AgentsService = Depends(get_agents_service),
) -> None:
    """Delete an agent session by its ID and its associated turns."""
    await svc.delete_agents_session(session_id=session_id, agent_id=agent_id)


@router.delete(
    "/agents/{agent_id}",
    response_model=None,
    status_code=204,
    summary="Delete an agent.",
    description="Delete an agent by its ID.",
    deprecated=True,
)
@router_v1alpha.delete(
    "/agents/{agent_id}",
    response_model=None,
    status_code=204,
    summary="Delete an agent.",
    description="Delete an agent by its ID.",
)
async def delete_agent(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to delete.")],
    svc: AgentsService = Depends(get_agents_service),
) -> None:
    """Delete an agent by its ID and its associated sessions and turns."""
    await svc.delete_agent(agent_id=agent_id)


@router.get(
    "/agents",
    summary="List all agents.",
    description="List all agents.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents",
    summary="List all agents.",
    description="List all agents.",
)
async def list_agents(
    start_index: int | None = Query(None, description="The index to start the pagination from."),
    limit: int | None = Query(None, description="The number of agents to return."),
    svc: AgentsService = Depends(get_agents_service),
):
    """List all agents."""
    return await svc.list_agents(start_index=start_index, limit=limit)


@router.get(
    "/agents/{agent_id}",
    response_model=Agent,
    summary="Describe an agent.",
    description="Describe an agent by its ID.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents/{agent_id}",
    response_model=Agent,
    summary="Describe an agent.",
    description="Describe an agent by its ID.",
)
async def get_agent(
    agent_id: Annotated[str, FastAPIPath(..., description="ID of the agent.")],
    svc: AgentsService = Depends(get_agents_service),
) -> Agent:
    """Describe an agent by its ID."""
    return await svc.get_agent(agent_id=agent_id)


@router.get(
    "/agents/{agent_id}/sessions",
    summary="List all sessions of an agent.",
    description="List all session(s) of a given agent.",
    deprecated=True,
)
@router_v1alpha.get(
    "/agents/{agent_id}/sessions",
    summary="List all sessions of an agent.",
    description="List all session(s) of a given agent.",
)
async def list_agent_sessions(
    agent_id: Annotated[str, FastAPIPath(..., description="The ID of the agent to list sessions for.")],
    start_index: int | None = Query(None, description="The index to start the pagination from."),
    limit: int | None = Query(None, description="The number of sessions to return."),
    svc: AgentsService = Depends(get_agents_service),
):
    """List all session(s) of a given agent."""
    return await svc.list_agent_sessions(agent_id=agent_id, start_index=start_index, limit=limit)


# OpenAI Responses API endpoints
@router.get(
    "/responses/{response_id}",
    response_model=OpenAIResponseObject,
    summary="Get a model response.",
    description="Get a model response.",
)
async def get_openai_response(
    response_id: Annotated[str, FastAPIPath(..., description="The ID of the OpenAI response to retrieve.")],
    svc: AgentsService = Depends(get_agents_service),
) -> OpenAIResponseObject:
    """Get a model response."""
    return await svc.get_openai_response(response_id=response_id)


@router.post(
    "/responses",
    summary="Create a model response.",
    description="Create a model response.",
)
async def create_openai_response(
    body: CreateOpenAIResponseRequest = Body(...),
    svc: AgentsService = Depends(get_agents_service),
):
    """Create a model response."""
    return await svc.create_openai_response(
        input=body.input,
        model=body.model,
        prompt=body.prompt,
        instructions=body.instructions,
        previous_response_id=body.previous_response_id,
        conversation=body.conversation,
        store=body.store,
        stream=body.stream,
        temperature=body.temperature,
        text=body.text,
        tools=body.tools,
        include=body.include,
        max_infer_iters=body.max_infer_iters,
        guardrails=body.guardrails,
    )


@router.get(
    "/responses",
    response_model=ListOpenAIResponseObject,
    summary="List all responses.",
    description="List all responses.",
)
async def list_openai_responses(
    after: str | None = Query(None, description="The ID of the last response to return."),
    limit: int | None = Query(50, description="The number of responses to return."),
    model: str | None = Query(None, description="The model to filter responses by."),
    order: Order | None = Query(
        Order.desc, description="The order to sort responses by when sorted by created_at ('asc' or 'desc')."
    ),
    svc: AgentsService = Depends(get_agents_service),
) -> ListOpenAIResponseObject:
    """List all responses."""
    return await svc.list_openai_responses(after=after, limit=limit, model=model, order=order)


@router.get(
    "/responses/{response_id}/input_items",
    response_model=ListOpenAIResponseInputItem,
    summary="List input items.",
    description="List input items.",
)
async def list_openai_response_input_items(
    response_id: Annotated[str, FastAPIPath(..., description="The ID of the response to retrieve input items for.")],
    after: str | None = Query(None, description="An item ID to list items after, used for pagination."),
    before: str | None = Query(None, description="An item ID to list items before, used for pagination."),
    include: list[str] | None = Query(None, description="Additional fields to include in the response."),
    limit: int | None = Query(
        20,
        description="A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.",
        ge=1,
        le=100,
    ),
    order: Order | None = Query(Order.desc, description="The order to return the input items in. Default is desc."),
    svc: AgentsService = Depends(get_agents_service),
) -> ListOpenAIResponseInputItem:
    """List input items."""
    return await svc.list_openai_response_input_items(
        response_id=response_id, after=after, before=before, include=include, limit=limit, order=order
    )


@router.delete(
    "/responses/{response_id}",
    response_model=OpenAIDeleteResponseObject,
    summary="Delete a response.",
    description="Delete a response.",
)
async def delete_openai_response(
    response_id: Annotated[str, FastAPIPath(..., description="The ID of the OpenAI response to delete.")],
    svc: AgentsService = Depends(get_agents_service),
) -> OpenAIDeleteResponseObject:
    """Delete a response."""
    return await svc.delete_openai_response(response_id=response_id)


# For backward compatibility with the router registry system
def create_agents_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Agents API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1alpha)
    return main_router


# Register the router factory
register_router(Api.agents, create_agents_router)
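
To see how these routes come alive at runtime: `get_agents_service` pulls the implementation out of `request.app.state.impls` keyed by `Api`, so a server assembling the app wires things roughly like this (an illustrative sketch, not part of the diff; `my_agents_impl` is a stand-in for a concrete AgentsService implementation):

from fastapi import FastAPI

app = FastAPI()
app.state.impls = {Api.agents: my_agents_impl}  # my_agents_impl: any object satisfying AgentsService
app.include_router(create_agents_router(None))  # the legacy factory ignores its impl_getter argument
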
src/llama_stack/apis/batches/__init__.py

@@ -4,6 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .batches import Batches, BatchObject, ListBatchesResponse
-
-__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
+try:
+    from openai.types import Batch as BatchObject
+except ImportError:
+    BatchObject = None  # type: ignore[assignment,misc]
+
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .batches_service import BatchService
+from .models import CreateBatchRequest, ListBatchesResponse
+
+# Backward compatibility - export Batches as alias for BatchService
+Batches = BatchService
+
+__all__ = ["Batches", "BatchService", "BatchObject", "ListBatchesResponse", "CreateBatchRequest"]

src/llama_stack/apis/batches/batches.py (deleted)

@@ -1,96 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


@runtime_checkable
class Batches(Protocol):
    """
    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
        idempotency_key: str | None = None,
    ) -> BatchObject:
        """Create a new batch for processing multiple API requests.

        :param input_file_id: The ID of an uploaded file containing requests for the batch.
        :param endpoint: The endpoint to be used for all requests in the batch.
        :param completion_window: The time window within which the batch should be processed.
        :param metadata: Optional metadata for the batch.
        :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
        :returns: The created batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch.

        :param batch_id: The ID of the batch to retrieve.
        :returns: The batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress.

        :param batch_id: The ID of the batch to cancel.
        :returns: The updated batch object.
        """
        ...

    @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """List all batches for the current user.

        :param after: A cursor for pagination; returns batches after this batch ID.
        :param limit: Number of batches to return (default 20, max 100).
        :returns: A list of batch objects.
        """
        ...
56 src/llama_stack/apis/batches/batches_service.py Normal file

@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e

from .models import ListBatchesResponse


@runtime_checkable
class BatchService(Protocol):
    """The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
        idempotency_key: str | None = None,
    ) -> BatchObject:
        """Create a new batch for processing multiple API requests."""
        ...

    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch."""
        ...

    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress."""
        ...

    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """List all batches for the current user."""
        ...
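
Because `BatchService` is a `runtime_checkable` Protocol, any object with matching async methods satisfies it structurally; note that `isinstance` only checks method presence, not signatures. A minimal in-memory stub (illustrative only, not part of the diff):

class InMemoryBatchService:
    """Toy implementation used only to demonstrate protocol conformance."""

    def __init__(self) -> None:
        self._batches: dict[str, BatchObject] = {}

    async def create_batch(self, input_file_id, endpoint, completion_window, metadata=None, idempotency_key=None):
        raise NotImplementedError  # a real provider would enqueue work here

    async def retrieve_batch(self, batch_id: str):
        return self._batches[batch_id]

    async def cancel_batch(self, batch_id: str):
        raise NotImplementedError

    async def list_batches(self, after=None, limit=20):
        return ListBatchesResponse(data=list(self._batches.values())[:limit])


assert isinstance(InMemoryBatchService(), BatchService)  # passes: all four methods are present
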
42 src/llama_stack/apis/batches/models.py Normal file

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class CreateBatchRequest(BaseModel):
    """Request model for creating a batch."""

    input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.")
    endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.")
    completion_window: Literal["24h"] = Field(
        ..., description="The time window within which the batch should be processed."
    )
    metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.")
    idempotency_key: str | None = Field(
        default=None, description="Optional idempotency key. When provided, enables idempotent behavior."
    )


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = Field(default="list", description="The object type, which is always 'list'.")
    data: list[BatchObject] = Field(..., description="List of batch objects.")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list.")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list.")
    has_more: bool = Field(default=False, description="Whether there are more batches available.")
111 src/llama_stack/apis/batches/routes.py Normal file

@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .batches_service import BatchService
from .models import CreateBatchRequest, ListBatchesResponse


def get_batch_service(request: Request) -> BatchService:
    """Dependency to get the batch service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.batches not in impls:
        raise ValueError("Batches API implementation not found")
    return impls[Api.batches]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Batches"],
    responses=standard_responses,
)


@router.post(
    "/batches",
    response_model=BatchObject,
    summary="Create a new batch for processing multiple API requests.",
    description="Create a new batch for processing multiple API requests.",
)
async def create_batch(
    request: CreateBatchRequest = Body(...),
    svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
    """Create a new batch."""
    return await svc.create_batch(
        input_file_id=request.input_file_id,
        endpoint=request.endpoint,
        completion_window=request.completion_window,
        metadata=request.metadata,
        idempotency_key=request.idempotency_key,
    )


@router.get(
    "/batches/{batch_id}",
    response_model=BatchObject,
    summary="Retrieve information about a specific batch.",
    description="Retrieve information about a specific batch.",
)
async def retrieve_batch(
    batch_id: Annotated[str, FastAPIPath(..., description="The ID of the batch to retrieve.")],
    svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
    """Retrieve batch information."""
    return await svc.retrieve_batch(batch_id)


@router.post(
    "/batches/{batch_id}/cancel",
    response_model=BatchObject,
    summary="Cancel a batch that is in progress.",
    description="Cancel a batch that is in progress.",
)
async def cancel_batch(
    batch_id: Annotated[str, FastAPIPath(..., description="The ID of the batch to cancel.")],
    svc: BatchService = Depends(get_batch_service),
) -> BatchObject:
    """Cancel a batch."""
    return await svc.cancel_batch(batch_id)


@router.get(
    "/batches",
    response_model=ListBatchesResponse,
    summary="List all batches for the current user.",
    description="List all batches for the current user.",
)
async def list_batches(
    after: str | None = Query(None, description="A cursor for pagination; returns batches after this batch ID."),
    limit: int = Query(20, description="Number of batches to return (default 20, max 100).", ge=1, le=100),
    svc: BatchService = Depends(get_batch_service),
) -> ListBatchesResponse:
    """List all batches."""
    return await svc.list_batches(after=after, limit=limit)


# For backward compatibility with the router registry system
def create_batches_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Batches API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.batches, create_batches_router)
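
Since the routes above mirror the OpenAI Batch API shapes, the stock `openai` client can in principle target a running stack. A sketch under assumptions: the base URL, API key, and file ID below are placeholders, and the exact OpenAI-compatible prefix depends on the server configuration:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # placeholder URL/key
batch = client.batches.create(
    input_file_id="file-abc123",       # placeholder: an already-uploaded requests file
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(client.batches.retrieve(batch.id).status)
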
src/llama_stack/apis/benchmarks/__init__.py

@@ -4,4 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .benchmarks import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .benchmarks_service import BenchmarksService
+from .models import (
+    Benchmark,
+    BenchmarkInput,
+    CommonBenchmarkFields,
+    ListBenchmarksResponse,
+    RegisterBenchmarkRequest,
+)
+
+# Backward compatibility - export Benchmarks as alias for BenchmarksService
+Benchmarks = BenchmarksService
+
+__all__ = [
+    "Benchmarks",
+    "BenchmarksService",
+    "Benchmark",
+    "BenchmarkInput",
+    "CommonBenchmarkFields",
+    "ListBenchmarksResponse",
+    "RegisterBenchmarkRequest",
+]

src/llama_stack/apis/benchmarks/benchmarks.py (deleted)

@@ -1,104 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod


class CommonBenchmarkFields(BaseModel):
    dataset_id: str
    scoring_functions: list[str]
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )


@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
    """A benchmark resource for evaluating model performance.

    :param dataset_id: Identifier of the dataset to use for the benchmark evaluation
    :param scoring_functions: List of scoring function identifiers to apply during evaluation
    :param metadata: Metadata for this evaluation task
    :param type: The resource type, always benchmark
    """

    type: Literal[ResourceType.benchmark] = ResourceType.benchmark

    @property
    def benchmark_id(self) -> str:
        return self.identifier

    @property
    def provider_benchmark_id(self) -> str | None:
        return self.provider_resource_id


class BenchmarkInput(CommonBenchmarkFields, BaseModel):
    benchmark_id: str
    provider_id: str | None = None
    provider_benchmark_id: str | None = None


class ListBenchmarksResponse(BaseModel):
    data: list[Benchmark]


@runtime_checkable
class Benchmarks(Protocol):
    @webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        """List all benchmarks.

        :returns: A ListBenchmarksResponse.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_benchmark(
        self,
        benchmark_id: str,
    ) -> Benchmark:
        """Get a benchmark by its ID.

        :param benchmark_id: The ID of the benchmark to get.
        :returns: A Benchmark.
        """
        ...

    @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Register a benchmark.

        :param benchmark_id: The ID of the benchmark to register.
        :param dataset_id: The ID of the dataset to use for the benchmark.
        :param scoring_functions: The scoring functions to use for the benchmark.
        :param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
        :param provider_id: The ID of the provider to use for the benchmark.
        :param metadata: The metadata to use for the benchmark.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
    async def unregister_benchmark(self, benchmark_id: str) -> None:
        """Unregister a benchmark.

        :param benchmark_id: The ID of the benchmark to unregister.
        """
        ...
src/llama_stack/apis/benchmarks/benchmarks_service.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import Benchmark, ListBenchmarksResponse


@runtime_checkable
@trace_protocol
class BenchmarksService(Protocol):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        """List all benchmarks."""
        ...

    async def get_benchmark(
        self,
        benchmark_id: str,
    ) -> Benchmark:
        """Get a benchmark by its ID."""
        ...

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Register a benchmark."""
        ...

    async def unregister_benchmark(self, benchmark_id: str) -> None:
        """Unregister a benchmark."""
        ...
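The service protocol above only declares the surface; a provider supplies the behavior. A minimal in-memory sketch of a conforming implementation follows. It is hypothetical and for illustration only: the constructor keyword names for Benchmark are assumed to follow the Resource base model, and a real provider would also handle persistence and provider routing.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, BenchmarksService, ListBenchmarksResponse


class InMemoryBenchmarks:
    """Hypothetical provider: stores benchmarks in a dict, no persistence."""

    def __init__(self) -> None:
        self._benchmarks: dict[str, Benchmark] = {}

    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=list(self._benchmarks.values()))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        return self._benchmarks[benchmark_id]

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        self._benchmarks[benchmark_id] = Benchmark(
            identifier=benchmark_id,
            provider_id=provider_id or "inline",  # assumption: Resource requires a provider_id
            provider_resource_id=provider_benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata or {},
        )

    async def unregister_benchmark(self, benchmark_id: str) -> None:
        self._benchmarks.pop(benchmark_id, None)


# BenchmarksService is @runtime_checkable, so conformance can be checked structurally:
assert isinstance(InMemoryBenchmarks(), BenchmarksService)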
src/llama_stack/apis/benchmarks/models.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Literal

from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type


class CommonBenchmarkFields(BaseModel):
    dataset_id: str = Field(..., description="The ID of the dataset to use for the benchmark")
    scoring_functions: list[str] = Field(..., description="The scoring functions to use for the benchmark")
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )


@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
    """A benchmark resource for evaluating model performance."""

    type: Literal[ResourceType.benchmark] = Field(
        default=ResourceType.benchmark, description="The resource type, always benchmark"
    )


class ListBenchmarksResponse(BaseModel):
    """Response model for listing benchmarks."""

    data: list[Benchmark] = Field(..., description="List of benchmark resources")


@json_schema_type
class RegisterBenchmarkRequest(BaseModel):
    """Request model for registering a benchmark."""

    benchmark_id: str = Field(..., description="The ID of the benchmark to register")
    dataset_id: str = Field(..., description="The ID of the dataset to use for the benchmark")
    scoring_functions: list[str] = Field(..., description="The scoring functions to use for the benchmark")
    provider_benchmark_id: str | None = Field(
        default=None, description="The ID of the provider benchmark to use for the benchmark"
    )
    provider_id: str | None = Field(default=None, description="The ID of the provider to use for the benchmark")
    metadata: dict[str, Any] | None = Field(default=None, description="The metadata to use for the benchmark")


class BenchmarkInput(CommonBenchmarkFields, BaseModel):
    benchmark_id: str = Field(..., description="The ID of the benchmark")
    provider_id: str | None = Field(default=None, description="The ID of the provider to use for the benchmark")
    provider_benchmark_id: str | None = Field(
        default=None, description="The ID of the provider benchmark to use for the benchmark"
    )
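A quick construction sketch for the request model defined above; all values are illustrative, including the scoring function identifier, which is an assumption rather than a name taken from this diff.

req = RegisterBenchmarkRequest(
    benchmark_id="mmlu-mini",                  # illustrative ID
    dataset_id="mmlu-mini-dataset",            # illustrative dataset
    scoring_functions=["basic::regex_match"],  # assumed scoring function identifier
    metadata={"split": "validation"},
)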
src/llama_stack/apis/benchmarks/routes.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .benchmarks_service import BenchmarksService
from .models import Benchmark, ListBenchmarksResponse, RegisterBenchmarkRequest


def get_benchmarks_service(request: Request) -> BenchmarksService:
    """Dependency to get the benchmarks service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.benchmarks not in impls:
        raise ValueError("Benchmarks API implementation not found")
    return impls[Api.benchmarks]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Benchmarks"],
    responses=standard_responses,
)

router_v1alpha = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
    tags=["Benchmarks"],
    responses=standard_responses,
)


@router.get(
    "/eval/benchmarks",
    response_model=ListBenchmarksResponse,
    summary="List all benchmarks",
    description="List all benchmarks",
    deprecated=True,
)
@router_v1alpha.get(
    "/eval/benchmarks",
    response_model=ListBenchmarksResponse,
    summary="List all benchmarks",
    description="List all benchmarks",
)
async def list_benchmarks(svc: BenchmarksService = Depends(get_benchmarks_service)) -> ListBenchmarksResponse:
    """List all benchmarks."""
    return await svc.list_benchmarks()


@router.get(
    "/eval/benchmarks/{benchmark_id}",
    response_model=Benchmark,
    summary="Get a benchmark by its ID",
    description="Get a benchmark by its ID",
    deprecated=True,
)
@router_v1alpha.get(
    "/eval/benchmarks/{benchmark_id}",
    response_model=Benchmark,
    summary="Get a benchmark by its ID",
    description="Get a benchmark by its ID",
)
async def get_benchmark(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to get")],
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> Benchmark:
    """Get a benchmark by its ID."""
    return await svc.get_benchmark(benchmark_id=benchmark_id)


@router.post(
    "/eval/benchmarks",
    response_model=None,
    status_code=204,
    summary="Register a benchmark",
    description="Register a benchmark",
    deprecated=True,
)
@router_v1alpha.post(
    "/eval/benchmarks",
    response_model=None,
    status_code=204,
    summary="Register a benchmark",
    description="Register a benchmark",
)
async def register_benchmark(
    body: RegisterBenchmarkRequest = Body(...),
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
    """Register a benchmark."""
    return await svc.register_benchmark(
        benchmark_id=body.benchmark_id,
        dataset_id=body.dataset_id,
        scoring_functions=body.scoring_functions,
        provider_benchmark_id=body.provider_benchmark_id,
        provider_id=body.provider_id,
        metadata=body.metadata,
    )


@router.delete(
    "/eval/benchmarks/{benchmark_id}",
    response_model=None,
    status_code=204,
    summary="Unregister a benchmark",
    description="Unregister a benchmark",
    deprecated=True,
)
@router_v1alpha.delete(
    "/eval/benchmarks/{benchmark_id}",
    response_model=None,
    status_code=204,
    summary="Unregister a benchmark",
    description="Unregister a benchmark",
)
async def unregister_benchmark(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to unregister")],
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
    """Unregister a benchmark."""
    await svc.unregister_benchmark(benchmark_id=benchmark_id)


# For backward compatibility with the router registry system
def create_benchmarks_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Benchmarks API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1alpha)
    return main_router


# Register the router factory
register_router(Api.benchmarks, create_benchmarks_router)
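A sketch of how these module-level routers come together at serving time, assuming the server attaches the implementations mapping to app.state.impls exactly as the get_benchmarks_service dependency above expects. InMemoryBenchmarks is the hypothetical provider sketched earlier; everything else uses names from this diff or standard FastAPI.

from fastapi import FastAPI

from llama_stack.apis.benchmarks import routes as benchmark_routes
from llama_stack.apis.datatypes import Api

app = FastAPI()
# The dependency resolves the implementation from this mapping at request time.
app.state.impls = {Api.benchmarks: InMemoryBenchmarks()}

# Mount both the deprecated /v1 surface and the current /v1alpha surface.
app.include_router(benchmark_routes.router)
app.include_router(benchmark_routes.router_v1alpha)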
@@ -15,21 +15,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema
 @json_schema_type
 class URL(BaseModel):
-    """A URL reference to external content.
-
-    :param uri: The URL string pointing to the resource
-    """
+    """A URL reference to external content."""
 
     uri: str
 
 
 class _URLOrData(BaseModel):
-    """
-    A URL or a base64 encoded string
-
-    :param url: A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits.
-    :param data: base64 encoded image data as string
-    """
+    """A URL or a base64 encoded string."""
 
     url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64

@@ -45,11 +37,7 @@ class _URLOrData(BaseModel):
 @json_schema_type
 class ImageContentItem(BaseModel):
-    """A image content item
-
-    :param type: Discriminator type of the content item. Always "image"
-    :param image: Image as a base64 encoded string or an URL
-    """
+    """A image content item."""
 
     type: Literal["image"] = "image"
     image: _URLOrData

@@ -57,11 +45,7 @@ class ImageContentItem(BaseModel):
 @json_schema_type
 class TextContentItem(BaseModel):
-    """A text content item
-
-    :param type: Discriminator type of the content item. Always "text"
-    :param text: Text content
-    """
+    """A text content item."""
 
     type: Literal["text"] = "text"
     text: str

@@ -81,11 +65,7 @@ register_schema(InterleavedContent, name="InterleavedContent")
 @json_schema_type
 class TextDelta(BaseModel):
-    """A text content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "text"
-    :param text: The incremental text content
-    """
+    """A text content delta for streaming responses."""
 
     type: Literal["text"] = "text"
     text: str

@@ -93,23 +73,14 @@ class TextDelta(BaseModel):
 @json_schema_type
 class ImageDelta(BaseModel):
-    """An image content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "image"
-    :param image: The incremental image data as bytes
-    """
+    """An image content delta for streaming responses."""
 
     type: Literal["image"] = "image"
     image: bytes
 
 
 class ToolCallParseStatus(Enum):
-    """Status of tool call parsing during streaming.
-    :cvar started: Tool call parsing has begun
-    :cvar in_progress: Tool call parsing is ongoing
-    :cvar failed: Tool call parsing failed
-    :cvar succeeded: Tool call parsing completed successfully
-    """
+    """Status of tool call parsing during streaming."""
 
     started = "started"
     in_progress = "in_progress"

@@ -119,12 +90,7 @@ class ToolCallParseStatus(Enum):
 @json_schema_type
 class ToolCallDelta(BaseModel):
-    """A tool call content delta for streaming responses.
-
-    :param type: Discriminator type of the delta. Always "tool_call"
-    :param tool_call: Either an in-progress tool call string or the final parsed tool call
-    :param parse_status: Current parsing status of the tool call
-    """
+    """A tool call content delta for streaming responses."""
 
     type: Literal["tool_call"] = "tool_call"
@@ -28,11 +28,7 @@ class JobStatus(Enum):
 
 
 @json_schema_type
 class Job(BaseModel):
-    """A job execution instance with status tracking.
-
-    :param job_id: Unique identifier for the job
-    :param status: Current execution status of the job
-    """
+    """A job execution instance with status tracking."""
 
     job_id: str
     status: JobStatus
@@ -7,16 +7,13 @@
 from enum import Enum
 from typing import Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from llama_stack.schema_utils import json_schema_type
 
 
 class Order(Enum):
-    """Sort order for paginated responses.
-    :cvar asc: Ascending order
-    :cvar desc: Descending order
-    """
+    """Sort order for paginated responses."""
 
     asc = "asc"
     desc = "desc"

@@ -24,13 +21,8 @@ class Order(Enum):
 @json_schema_type
 class PaginatedResponse(BaseModel):
-    """A generic paginated response that follows a simple format.
-
-    :param data: The list of items for the current page
-    :param has_more: Whether there are more items available after this set
-    :param url: The URL for accessing this list
-    """
-
-    data: list[dict[str, Any]]
-    has_more: bool
-    url: str | None = None
+    """A generic paginated response that follows a simple format."""
+
+    data: list[dict[str, Any]] = Field(description="The list of items for the current page.")
+    has_more: bool = Field(description="Whether there are more items available after this set.")
+    url: str | None = Field(default=None, description="The URL for accessing this list.")
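The has_more flag is what drives client-side iteration over such responses. A small consumer sketch, where fetch_page is a hypothetical async callable returning a PaginatedResponse:

async def fetch_all(fetch_page) -> list[dict]:
    # Drain an offset-paginated endpoint page by page.
    rows: list[dict] = []
    start, page_size = 0, 100
    while True:
        page = await fetch_page(start_index=start, limit=page_size)
        rows.extend(page.data)
        if not page.has_more:
            return rows
        start += len(page.data)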
@@ -6,42 +6,28 @@
 from datetime import datetime
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
 class PostTrainingMetric(BaseModel):
-    """Training metrics captured during post-training jobs.
-
-    :param epoch: Training epoch number
-    :param train_loss: Loss value on the training dataset
-    :param validation_loss: Loss value on the validation dataset
-    :param perplexity: Perplexity metric indicating model confidence
-    """
-
-    epoch: int
-    train_loss: float
-    validation_loss: float
-    perplexity: float
+    """Training metrics captured during post-training jobs."""
+
+    epoch: int = Field(description="Training epoch number.")
+    train_loss: float = Field(description="Loss value on the training dataset.")
+    validation_loss: float = Field(description="Loss value on the validation dataset.")
+    perplexity: float = Field(description="Perplexity metric indicating model confidence.")
 
 
 @json_schema_type
 class Checkpoint(BaseModel):
-    """Checkpoint created during training runs.
-
-    :param identifier: Unique identifier for the checkpoint
-    :param created_at: Timestamp when the checkpoint was created
-    :param epoch: Training epoch when the checkpoint was saved
-    :param post_training_job_id: Identifier of the training job that created this checkpoint
-    :param path: File system path where the checkpoint is stored
-    :param training_metrics: (Optional) Training metrics associated with this checkpoint
-    """
-
-    identifier: str
-    created_at: datetime
-    epoch: int
-    post_training_job_id: str
-    path: str
-    training_metrics: PostTrainingMetric | None = None
+    """Checkpoint created during training runs."""
+
+    identifier: str = Field(description="Unique identifier for the checkpoint.")
+    created_at: datetime = Field(description="Timestamp when the checkpoint was created.")
+    epoch: int = Field(description="Training epoch when the checkpoint was saved.")
+    post_training_job_id: str = Field(description="Identifier of the training job that created this checkpoint.")
+    path: str = Field(description="File system path where the checkpoint is stored.")
+    training_metrics: PostTrainingMetric | None = Field(default=None, description="Training metrics associated with this checkpoint.")
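A construction sketch for the annotated models above; all values are illustrative.

from datetime import datetime, timezone

metric = PostTrainingMetric(epoch=1, train_loss=0.42, validation_loss=0.51, perplexity=1.67)
checkpoint = Checkpoint(
    identifier="ckpt-0001",
    created_at=datetime.now(timezone.utc),
    epoch=1,
    post_training_job_id="job-123",
    path="/tmp/checkpoints/ckpt-0001",
    training_metrics=metric,
)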
@@ -4,28 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .conversations import (
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .conversations_service import ConversationService
+from .models import (
     Conversation,
     ConversationCreateRequest,
     ConversationDeletedResource,
     ConversationItem,
     ConversationItemCreateRequest,
     ConversationItemDeletedResource,
+    ConversationItemInclude,
     ConversationItemList,
-    Conversations,
+    ConversationMessage,
     ConversationUpdateRequest,
     Metadata,
 )
 
+# Backward compatibility - export Conversations as alias for ConversationService
+Conversations = ConversationService
+
 __all__ = [
-    "Conversation",
-    "ConversationCreateRequest",
-    "ConversationDeletedResource",
-    "ConversationItem",
-    "ConversationItemCreateRequest",
-    "ConversationItemDeletedResource",
-    "ConversationItemList",
     "Conversations",
+    "ConversationService",
+    "Conversation",
+    "ConversationMessage",
+    "ConversationItem",
+    "ConversationCreateRequest",
     "ConversationUpdateRequest",
+    "ConversationDeletedResource",
+    "ConversationItemCreateRequest",
+    "ConversationItemList",
+    "ConversationItemDeletedResource",
+    "ConversationItemInclude",
     "Metadata",
 ]
src/llama_stack/apis/conversations/conversations_service.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemDeletedResource,
    ConversationItemInclude,
    ConversationItemList,
    Metadata,
)


@runtime_checkable
@trace_protocol
class ConversationService(Protocol):
    """Conversations

    Protocol for conversation management operations."""

    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation."""
        ...

    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Retrieve a conversation."""
        ...

    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation."""
        ...

    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation."""
        ...

    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create items."""
        ...

    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve an item."""
        ...

    async def list_items(
        self,
        conversation_id: str,
        after: str | None = None,
        include: list[ConversationItemInclude] | None = None,
        limit: int | None = None,
        order: Literal["asc", "desc"] | None = None,
    ) -> ConversationItemList:
        """List items."""
        ...

    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete an item."""
        ...
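Because the protocol is @runtime_checkable, a server can validate a provider structurally at startup. A minimal check sketch; note that isinstance against a runtime-checkable Protocol verifies method presence only, not signatures:

def validate_impl(impl: object) -> None:
    # Structural conformance check against the protocol above.
    if not isinstance(impl, ConversationService):
        raise TypeError(f"{type(impl).__name__} does not implement ConversationService")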
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import StrEnum
-from typing import Annotated, Literal, Protocol, runtime_checkable
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field

@@ -20,9 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+from llama_stack.schema_utils import json_schema_type, register_schema
 
 Metadata = dict[str, str]

@@ -76,31 +74,6 @@ ConversationItem = Annotated[
 ]
 register_schema(ConversationItem, name="ConversationItem")
 
-# Using OpenAI types directly caused issues but some notes for reference:
-# Note that ConversationItem is a Annotated Union of the types below:
-# from openai.types.responses import *
-# from openai.types.responses.response_item import *
-# from openai.types.conversations import ConversationItem
-# f = [
-#     ResponseFunctionToolCallItem,
-#     ResponseFunctionToolCallOutputItem,
-#     ResponseFileSearchToolCall,
-#     ResponseFunctionWebSearch,
-#     ImageGenerationCall,
-#     ResponseComputerToolCall,
-#     ResponseComputerToolCallOutputItem,
-#     ResponseReasoningItem,
-#     ResponseCodeInterpreterToolCall,
-#     LocalShellCall,
-#     LocalShellCallOutput,
-#     McpListTools,
-#     McpApprovalRequest,
-#     McpApprovalResponse,
-#     McpCall,
-#     ResponseCustomToolCall,
-#     ResponseCustomToolCallOutput
-# ]
-
 
 @json_schema_type
 class ConversationCreateRequest(BaseModel):

@@ -180,119 +153,3 @@ class ConversationItemDeletedResource(BaseModel):
     id: str = Field(..., description="The deleted item identifier")
     object: str = Field(default="conversation.item.deleted", description="Object type")
     deleted: bool = Field(default=True, description="Whether the object was deleted")
-
-
-@runtime_checkable
-@trace_protocol
-class Conversations(Protocol):
-    """Conversations
-
-    Protocol for conversation management operations."""
-
-    @webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
-    async def create_conversation(
-        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
-    ) -> Conversation:
-        """Create a conversation.
-
-        Create a conversation.
-
-        :param items: Initial items to include in the conversation context.
-        :param metadata: Set of key-value pairs that can be attached to an object.
-        :returns: The created conversation object.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
-    async def get_conversation(self, conversation_id: str) -> Conversation:
-        """Retrieve a conversation.
-
-        Get a conversation with the given ID.
-
-        :param conversation_id: The conversation identifier.
-        :returns: The conversation object.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
-    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
-        """Update a conversation.
-
-        Update a conversation's metadata with the given ID.
-
-        :param conversation_id: The conversation identifier.
-        :param metadata: Set of key-value pairs that can be attached to an object.
-        :returns: The updated conversation object.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
-    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
-        """Delete a conversation.
-
-        Delete a conversation with the given ID.
-
-        :param conversation_id: The conversation identifier.
-        :returns: The deleted conversation resource.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
-    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
-        """Create items.
-
-        Create items in the conversation.
-
-        :param conversation_id: The conversation identifier.
-        :param items: Items to include in the conversation context.
-        :returns: List of created items.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
-    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
-        """Retrieve an item.
-
-        Retrieve a conversation item.
-
-        :param conversation_id: The conversation identifier.
-        :param item_id: The item identifier.
-        :returns: The conversation item.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
-    async def list_items(
-        self,
-        conversation_id: str,
-        after: str | None = None,
-        include: list[ConversationItemInclude] | None = None,
-        limit: int | None = None,
-        order: Literal["asc", "desc"] | None = None,
-    ) -> ConversationItemList:
-        """List items.
-
-        List items in the conversation.
-
-        :param conversation_id: The conversation identifier.
-        :param after: An item ID to list items after, used in pagination.
-        :param include: Specify additional output data to include in the response.
-        :param limit: A limit on the number of objects to be returned (1-100, default 20).
-        :param order: The order to return items in (asc or desc, default desc).
-        :returns: List of conversation items.
-        """
-        ...
-
-    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
-    async def openai_delete_conversation_item(
-        self, conversation_id: str, item_id: str
-    ) -> ConversationItemDeletedResource:
-        """Delete an item.
-
-        Delete a conversation item.
-
-        :param conversation_id: The conversation identifier.
-        :param item_id: The item identifier.
-        :returns: The deleted item resource.
-        """
-        ...
src/llama_stack/apis/conversations/routes.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Literal

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .conversations_service import ConversationService
from .models import (
    Conversation,
    ConversationCreateRequest,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemCreateRequest,
    ConversationItemDeletedResource,
    ConversationItemInclude,
    ConversationItemList,
    ConversationUpdateRequest,
)


def get_conversation_service(request: Request) -> ConversationService:
    """Dependency to get the conversation service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.conversations not in impls:
        raise ValueError("Conversations API implementation not found")
    return impls[Api.conversations]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Conversations"],
    responses=standard_responses,
)


@router.post(
    "/conversations",
    response_model=Conversation,
    summary="Create a conversation",
    description="Create a conversation",
)
async def create_conversation(
    body: ConversationCreateRequest = Body(...),
    svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
    """Create a conversation."""
    return await svc.create_conversation(items=body.items, metadata=body.metadata)


@router.get(
    "/conversations/{conversation_id}",
    response_model=Conversation,
    summary="Retrieve a conversation",
    description="Get a conversation with the given ID",
)
async def get_conversation(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
    """Get a conversation."""
    return await svc.get_conversation(conversation_id=conversation_id)


@router.post(
    "/conversations/{conversation_id}",
    response_model=Conversation,
    summary="Update a conversation",
    description="Update a conversation's metadata with the given ID",
)
async def update_conversation(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    body: ConversationUpdateRequest = Body(...),
    svc: ConversationService = Depends(get_conversation_service),
) -> Conversation:
    """Update a conversation."""
    return await svc.update_conversation(conversation_id=conversation_id, metadata=body.metadata)


@router.delete(
    "/conversations/{conversation_id}",
    response_model=ConversationDeletedResource,
    summary="Delete a conversation",
    description="Delete a conversation with the given ID",
)
async def openai_delete_conversation(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    svc: ConversationService = Depends(get_conversation_service),
) -> ConversationDeletedResource:
    """Delete a conversation."""
    return await svc.openai_delete_conversation(conversation_id=conversation_id)


@router.post(
    "/conversations/{conversation_id}/items",
    response_model=ConversationItemList,
    summary="Create items",
    description="Create items in the conversation",
)
async def add_items(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    body: ConversationItemCreateRequest = Body(...),
    svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemList:
    """Create items in the conversation."""
    return await svc.add_items(conversation_id=conversation_id, items=body.items)


@router.get(
    "/conversations/{conversation_id}/items/{item_id}",
    response_model=ConversationItem,
    summary="Retrieve an item",
    description="Retrieve a conversation item",
)
async def retrieve(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    item_id: Annotated[str, FastAPIPath(..., description="The item identifier")],
    svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItem:
    """Retrieve a conversation item."""
    return await svc.retrieve(conversation_id=conversation_id, item_id=item_id)


@router.get(
    "/conversations/{conversation_id}/items",
    response_model=ConversationItemList,
    summary="List items",
    description="List items in the conversation",
)
async def list_items(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    after: str | None = Query(None, description="An item ID to list items after, used in pagination"),
    include: list[ConversationItemInclude] | None = Query(
        None, description="Specify additional output data to include in the response"
    ),
    limit: int | None = Query(None, description="A limit on the number of objects to be returned (1-100, default 20)"),
    order: Literal["asc", "desc"] | None = Query(
        None, description="The order to return items in (asc or desc, default desc)"
    ),
    svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemList:
    """List items in the conversation."""
    return await svc.list_items(conversation_id=conversation_id, after=after, include=include, limit=limit, order=order)


@router.delete(
    "/conversations/{conversation_id}/items/{item_id}",
    response_model=ConversationItemDeletedResource,
    summary="Delete an item",
    description="Delete a conversation item",
)
async def openai_delete_conversation_item(
    conversation_id: Annotated[str, FastAPIPath(..., description="The conversation identifier")],
    item_id: Annotated[str, FastAPIPath(..., description="The item identifier")],
    svc: ConversationService = Depends(get_conversation_service),
) -> ConversationItemDeletedResource:
    """Delete a conversation item."""
    return await svc.openai_delete_conversation_item(conversation_id=conversation_id, item_id=item_id)


# For backward compatibility with the router registry system
def create_conversations_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Conversations API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.conversations, create_conversations_router)
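A hedged end-to-end sketch of the resulting HTTP surface. The paths come from the routes above; the base URL, the concrete value of the LLAMA_STACK_API_V1 prefix, the response field names, and the item payload shape are all assumptions for illustration.

import httpx

BASE = "http://localhost:8321/v1"  # assumes the server port and LLAMA_STACK_API_V1 == "v1"

with httpx.Client(base_url=BASE) as client:
    conv = client.post("/conversations", json={"items": [], "metadata": {"topic": "demo"}}).json()
    conv_id = conv["id"]  # assumed response field

    # Item shape is illustrative; real items follow the ConversationItem union.
    client.post(f"/conversations/{conv_id}/items", json={"items": [
        {"type": "message", "role": "user", "content": "Hello"},
    ]})
    items = client.get(f"/conversations/{conv_id}/items", params={"limit": 20}).json()
    client.delete(f"/conversations/{conv_id}")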
@@ -4,4 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .datasetio import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .datasetio_service import DatasetIOService, DatasetStore
+
+# Backward compatibility - export DatasetIO as alias for DatasetIOService
+DatasetIO = DatasetIOService
+
+__all__ = ["DatasetIO", "DatasetIOService", "DatasetStore"]
@@ -17,7 +17,7 @@ class DatasetStore(Protocol):
 
 
 @runtime_checkable
-class DatasetIO(Protocol):
+class DatasetIOService(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore

@@ -28,28 +28,10 @@ class DatasetIO(Protocol):
         start_index: int | None = None,
         limit: int | None = None,
     ) -> PaginatedResponse:
-        """Get a paginated list of rows from a dataset.
-
-        Uses offset-based pagination where:
-        - start_index: The starting index (0-based). If None, starts from beginning.
-        - limit: Number of items to return. If None or -1, returns all items.
-
-        The response includes:
-        - data: List of items for the current page.
-        - has_more: Whether there are more items available after this set.
-
-        :param dataset_id: The ID of the dataset to get the rows from.
-        :param start_index: Index into dataset for the first row to get. Get all rows if None.
-        :param limit: The number of rows to get.
-        :returns: A PaginatedResponse.
-        """
+        """Get a paginated list of rows from a dataset."""
        ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
-        """Append rows to a dataset.
-
-        :param dataset_id: The ID of the dataset to append the rows to.
-        :param rows: The rows to append to the dataset.
-        """
+        """Append rows to a dataset."""
        ...
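The rename stays source-compatible for existing callers through the alias exported from __init__.py; both names refer to the same protocol object:

from llama_stack.apis.datasetio import DatasetIO, DatasetIOService

assert DatasetIO is DatasetIOService  # alias, not a copy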
src/llama_stack/apis/datasetio/models.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
src/llama_stack/apis/datasetio/routes.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Any

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .datasetio_service import DatasetIOService


def get_datasetio_service(request: Request) -> DatasetIOService:
    """Dependency to get the datasetio service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.datasetio not in impls:
        raise ValueError("DatasetIO API implementation not found")
    return impls[Api.datasetio]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1BETA}",
    tags=["DatasetIO"],
    responses=standard_responses,
)


@router.get(
    "/datasetio/iterrows/{dataset_id:path}",
    response_model=PaginatedResponse,
    summary="Get a paginated list of rows from a dataset.",
    description="Get a paginated list of rows from a dataset using offset-based pagination.",
)
async def iterrows(
    dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to get the rows from")],
    start_index: int | None = Query(
        None, description="Index into dataset for the first row to get. Get all rows if None."
    ),
    limit: int | None = Query(None, description="The number of rows to get."),
    svc: DatasetIOService = Depends(get_datasetio_service),
) -> PaginatedResponse:
    """Get a paginated list of rows from a dataset."""
    return await svc.iterrows(dataset_id=dataset_id, start_index=start_index, limit=limit)


@router.post(
    "/datasetio/append-rows/{dataset_id:path}",
    response_model=None,
    status_code=204,
    summary="Append rows to a dataset.",
    description="Append rows to a dataset.",
)
async def append_rows(
    dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to append the rows to")],
    body: list[dict[str, Any]] = Body(..., description="The rows to append to the dataset."),
    svc: DatasetIOService = Depends(get_datasetio_service),
) -> None:
    """Append rows to a dataset."""
    await svc.append_rows(dataset_id=dataset_id, rows=body)


# For backward compatibility with the router registry system
def create_datasetio_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the DatasetIO API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.datasetio, create_datasetio_router)
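A usage sketch against the two routes above. The query parameters and body shape come from the route definitions; the base URL, the concrete v1beta prefix value, and the dataset ID are illustrative assumptions.

import httpx

BASE = "http://localhost:8321/v1beta"  # assumes LLAMA_STACK_API_V1BETA == "v1beta"

with httpx.Client(base_url=BASE) as client:
    # append-rows takes the row list directly as the JSON body.
    client.post("/datasetio/append-rows/my-dataset", json=[{"question": "2+2?", "answer": "4"}])

    # iterrows pages with start_index/limit query parameters.
    page = client.get("/datasetio/iterrows/my-dataset", params={"start_index": 0, "limit": 10}).json()
    rows, has_more = page["data"], page["has_more"]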
@@ -4,4 +4,36 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .datasets import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .datasets_service import DatasetsService
+from .models import (
+    CommonDatasetFields,
+    Dataset,
+    DatasetInput,
+    DatasetPurpose,
+    DatasetType,
+    DataSource,
+    ListDatasetsResponse,
+    RegisterDatasetRequest,
+    RowsDataSource,
+    URIDataSource,
+)
+
+# Backward compatibility - export Datasets as alias for DatasetsService
+Datasets = DatasetsService
+
+__all__ = [
+    "Datasets",
+    "DatasetsService",
+    "Dataset",
+    "DatasetInput",
+    "CommonDatasetFields",
+    "DatasetPurpose",
+    "DatasetType",
+    "DataSource",
+    "URIDataSource",
+    "RowsDataSource",
+    "ListDatasetsResponse",
+    "RegisterDatasetRequest",
+]
@ -1,247 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum, StrEnum
|
|
||||||
from typing import Annotated, Any, Literal, Protocol
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetPurpose(StrEnum):
|
|
||||||
"""
|
|
||||||
Purpose of the dataset. Each purpose has a required input data schema.
|
|
||||||
|
|
||||||
:cvar post-training/messages: The dataset contains messages used for post-training.
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello, world!"},
|
|
||||||
{"role": "assistant", "content": "Hello, world!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
:cvar eval/question-answer: The dataset contains a question column and an answer column.
|
|
||||||
{
|
|
||||||
"question": "What is the capital of France?",
|
|
||||||
"answer": "Paris"
|
|
||||||
}
|
|
||||||
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
|
||||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
|
||||||
{"role": "user", "content": "What's my name?"},
|
|
||||||
],
|
|
||||||
"answer": "John Doe"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
post_training_messages = "post-training/messages"
|
|
||||||
eval_question_answer = "eval/question-answer"
|
|
||||||
eval_messages_answer = "eval/messages-answer"
|
|
||||||
|
|
||||||
# TODO: add more schemas here
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetType(Enum):
|
|
||||||
"""
|
|
||||||
Type of the dataset source.
|
|
||||||
:cvar uri: The dataset can be obtained from a URI.
|
|
||||||
:cvar rows: The dataset is stored in rows.
|
|
||||||
"""
|
|
||||||
|
|
||||||
uri = "uri"
|
|
||||||
rows = "rows"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class URIDataSource(BaseModel):
|
|
||||||
"""A dataset that can be obtained from a URI.
|
|
||||||
:param uri: The dataset can be obtained from a URI. E.g.
|
|
||||||
- "https://mywebsite.com/mydata.jsonl"
|
|
||||||
- "lsfs://mydata.jsonl"
|
|
||||||
- "data:csv;base64,{base64_content}"
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: Literal["uri"] = "uri"
|
|
||||||
uri: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RowsDataSource(BaseModel):
|
|
||||||
"""A dataset stored in rows.
|
|
||||||
:param rows: The dataset is stored in rows. E.g.
|
|
||||||
- [
|
|
||||||
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: Literal["rows"] = "rows"
|
|
||||||
rows: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
DataSource = Annotated[
|
|
||||||
URIDataSource | RowsDataSource,
|
|
||||||
Field(discriminator="type"),
|
|
||||||
]
|
|
||||||
register_schema(DataSource, name="DataSource")
|
|
||||||
|
|
||||||
|
|
||||||
class CommonDatasetFields(BaseModel):
|
|
||||||
"""
|
|
||||||
Common fields for a dataset.
|
|
||||||
|
|
||||||
:param purpose: Purpose of the dataset indicating its intended use
    :param source: Data source configuration for the dataset
    :param metadata: Additional metadata for the dataset
    """

    purpose: DatasetPurpose
    source: DataSource
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this dataset",
    )


@json_schema_type
class Dataset(CommonDatasetFields, Resource):
    """Dataset resource for storing and accessing training or evaluation data.

    :param type: Type of resource, always 'dataset' for datasets
    """

    type: Literal[ResourceType.dataset] = ResourceType.dataset

    @property
    def dataset_id(self) -> str:
        return self.identifier

    @property
    def provider_dataset_id(self) -> str | None:
        return self.provider_resource_id


class DatasetInput(CommonDatasetFields, BaseModel):
    """Input parameters for dataset operations.

    :param dataset_id: Unique identifier for the dataset
    """

    dataset_id: str


class ListDatasetsResponse(BaseModel):
    """Response from listing datasets.

    :param data: List of datasets
    """

    data: list[Dataset]


class Datasets(Protocol):
    @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        """
        Register a new dataset.

        :param purpose: The purpose of the dataset.
            One of:
            - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
                {
                    "messages": [
                        {"role": "user", "content": "Hello, world!"},
                        {"role": "assistant", "content": "Hello, world!"},
                    ]
                }
            - "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
                {
                    "question": "What is the capital of France?",
                    "answer": "Paris"
                }
            - "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
                {
                    "messages": [
                        {"role": "user", "content": "Hello, my name is John Doe."},
                        {"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
                        {"role": "user", "content": "What's my name?"},
                    ],
                    "answer": "John Doe"
                }
        :param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
            - {
                "type": "uri",
                "uri": "https://mywebsite.com/mydata.jsonl"
              }
            - {
                "type": "uri",
                "uri": "lsfs://mydata.jsonl"
              }
            - {
                "type": "uri",
                "uri": "data:csv;base64,{base64_content}"
              }
            - {
                "type": "uri",
                "uri": "huggingface://llamastack/simpleqa?split=train"
              }
            - {
                "type": "rows",
                "rows": [
                    {
                        "messages": [
                            {"role": "user", "content": "Hello, world!"},
                            {"role": "assistant", "content": "Hello, world!"},
                        ]
                    }
                ]
              }
        :param metadata: The metadata for the dataset.
            - E.g. {"description": "My dataset"}.
        :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
        :returns: A Dataset.
        """
        ...

    @webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
    async def get_dataset(
        self,
        dataset_id: str,
    ) -> Dataset:
        """Get a dataset by its ID.

        :param dataset_id: The ID of the dataset to get.
        :returns: A Dataset.
        """
        ...

    @webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
    async def list_datasets(self) -> ListDatasetsResponse:
        """List all datasets.

        :returns: A ListDatasetsResponse.
        """
        ...

    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
    async def unregister_dataset(
        self,
        dataset_id: str,
    ) -> None:
        """Unregister a dataset by its ID.

        :param dataset_id: The ID of the dataset to unregister.
        """
        ...
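For orientation, a minimal usage sketch of this protocol. Hedged: `my_impl` stands in for any concrete implementation, and the import path assumes the new models module added below in this commit.

import asyncio

from llama_stack.apis.datasets.models import DatasetPurpose, URIDataSource


async def demo(my_impl) -> None:
    # my_impl: any object satisfying the Datasets protocol above (hypothetical)
    ds = await my_impl.register_dataset(
        purpose=DatasetPurpose.eval_question_answer,
        source=URIDataSource(uri="https://mywebsite.com/mydata.jsonl"),
        metadata={"description": "My dataset"},
    )
    print(ds.dataset_id)

# asyncio.run(demo(my_impl))  # supply a concrete implementation here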
src/llama_stack/apis/datasets/datasets_service.py (new file, +65)
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import Dataset, DatasetPurpose, DataSource, ListDatasetsResponse


@runtime_checkable
@trace_protocol
class DatasetsService(Protocol):
    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        """
        Register a new dataset.

        :param purpose: The purpose of the dataset.
            One of:
            - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
            - "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
            - "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
        :param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset.
        :param metadata: The metadata for the dataset.
        :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
        :returns: A Dataset.
        """
        ...

    async def get_dataset(
        self,
        dataset_id: str,
    ) -> Dataset:
        """Get a dataset by its ID.

        :param dataset_id: The ID of the dataset to get.
        :returns: A Dataset.
        """
        ...

    async def list_datasets(self) -> ListDatasetsResponse:
        """List all datasets.

        :returns: A ListDatasetsResponse.
        """
        ...

    async def unregister_dataset(
        self,
        dataset_id: str,
    ) -> None:
        """Unregister a dataset by its ID.

        :param dataset_id: The ID of the dataset to unregister.
        """
        ...
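To make the service protocol concrete, a minimal in-memory sketch (illustrative only; real providers live elsewhere in the stack). The Resource fields `identifier`, `provider_id`, and `provider_resource_id` are assumed from llama_stack.apis.resource, and the ID scheme is arbitrary.

import uuid

from llama_stack.apis.datasets.models import Dataset, ListDatasetsResponse


class InMemoryDatasets:
    """Illustrative DatasetsService implementation backed by a dict."""

    def __init__(self) -> None:
        self._datasets: dict[str, Dataset] = {}

    async def register_dataset(self, purpose, source, metadata=None, dataset_id=None) -> Dataset:
        dataset_id = dataset_id or f"dataset-{uuid.uuid4().hex[:8]}"  # assumed ID scheme
        ds = Dataset(
            identifier=dataset_id,
            provider_id="inline",  # assumed Resource fields
            provider_resource_id=dataset_id,
            purpose=purpose,
            source=source,
            metadata=metadata or {},
        )
        self._datasets[dataset_id] = ds
        return ds

    async def get_dataset(self, dataset_id: str) -> Dataset:
        return self._datasets[dataset_id]

    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=list(self._datasets.values()))

    async def unregister_dataset(self, dataset_id: str) -> None:
        self._datasets.pop(dataset_id, None)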
src/llama_stack/apis/datasets/models.py (new file, +134)
@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, StrEnum
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema


class DatasetPurpose(StrEnum):
    """
    Purpose of the dataset. Each purpose has a required input data schema.

    - "post-training/messages":
        {
            "messages": [
                {"role": "user", "content": "Hello, world!"},
                {"role": "assistant", "content": "Hello, world!"},
            ]
        }
    - "eval/question-answer":
        {
            "question": "What is the capital of France?",
            "answer": "Paris"
        }
    - "eval/messages-answer":
        {
            "messages": [
                {"role": "user", "content": "Hello, my name is John Doe."},
                {"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
                {"role": "user", "content": "What's my name?"},
            ],
            "answer": "John Doe"
        }
    """

    post_training_messages = "post-training/messages"
    eval_question_answer = "eval/question-answer"
    eval_messages_answer = "eval/messages-answer"

    # TODO: add more schemas here


class DatasetType(Enum):
    """
    Type of the dataset source.
    """

    uri = "uri"
    rows = "rows"


@json_schema_type
class URIDataSource(BaseModel):
    """A dataset that can be obtained from a URI."""

    type: Literal["uri"] = Field(default="uri", description="The type of data source")
    uri: str = Field(
        ...,
        description="The dataset can be obtained from a URI. E.g. 'https://mywebsite.com/mydata.jsonl', 'lsfs://mydata.jsonl', 'data:csv;base64,{base64_content}'",
    )


@json_schema_type
class RowsDataSource(BaseModel):
    """A dataset stored in rows."""

    type: Literal["rows"] = Field(default="rows", description="The type of data source")
    rows: list[dict[str, Any]] = Field(
        ...,
        description="The dataset is stored in rows. E.g. [{'messages': [{'role': 'user', 'content': 'Hello, world!'}, {'role': 'assistant', 'content': 'Hello, world!'}]}]",
    )


DataSource = Annotated[
    URIDataSource | RowsDataSource,
    Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")


class CommonDatasetFields(BaseModel):
    """Common fields for a dataset."""

    purpose: DatasetPurpose
    source: DataSource
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this dataset",
    )


@json_schema_type
class Dataset(CommonDatasetFields, Resource):
    """Dataset resource for storing and accessing training or evaluation data."""

    type: Literal[ResourceType.dataset] = Field(
        default=ResourceType.dataset, description="Type of resource, always 'dataset' for datasets"
    )

    @property
    def dataset_id(self) -> str:
        return self.identifier

    @property
    def provider_dataset_id(self) -> str | None:
        return self.provider_resource_id


class DatasetInput(CommonDatasetFields, BaseModel):
    """Input parameters for dataset operations."""

    dataset_id: str = Field(..., description="Unique identifier for the dataset")


class ListDatasetsResponse(BaseModel):
    """Response from listing datasets."""

    data: list[Dataset] = Field(..., description="List of datasets")


@json_schema_type
class RegisterDatasetRequest(BaseModel):
    """Request model for registering a dataset."""

    purpose: DatasetPurpose = Field(..., description="The purpose of the dataset")
    source: DataSource = Field(..., description="The data source of the dataset")
    metadata: dict[str, Any] | None = Field(default=None, description="The metadata for the dataset")
    dataset_id: str | None = Field(
        default=None, description="The ID of the dataset. If not provided, an ID will be generated"
    )
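A quick sketch of how the discriminated DataSource union dispatches on the "type" field, assuming pydantic v2 TypeAdapter semantics:

from pydantic import TypeAdapter

from llama_stack.apis.datasets.models import DataSource, RowsDataSource, URIDataSource

adapter = TypeAdapter(DataSource)
src = adapter.validate_python({"type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"})
assert isinstance(src, URIDataSource)
rows = adapter.validate_python({"type": "rows", "rows": [{"question": "Q?", "answer": "A"}]})
assert isinstance(rows, RowsDataSource)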
src/llama_stack/apis/datasets/routes.py (new file, +140)
@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .datasets_service import DatasetsService
from .models import Dataset, ListDatasetsResponse, RegisterDatasetRequest


def get_datasets_service(request: Request) -> DatasetsService:
    """Dependency to get the datasets service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.datasets not in impls:
        raise ValueError("Datasets API implementation not found")
    return impls[Api.datasets]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Datasets"],
    responses=standard_responses,
)

router_v1beta = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1BETA}",
    tags=["Datasets"],
    responses=standard_responses,
)


@router.post(
    "/datasets",
    response_model=Dataset,
    summary="Register a new dataset",
    description="Register a new dataset",
    deprecated=True,
)
@router_v1beta.post(
    "/datasets",
    response_model=Dataset,
    summary="Register a new dataset",
    description="Register a new dataset",
)
async def register_dataset(
    body: RegisterDatasetRequest = Body(...),
    svc: DatasetsService = Depends(get_datasets_service),
) -> Dataset:
    """Register a new dataset."""
    return await svc.register_dataset(
        purpose=body.purpose,
        source=body.source,
        metadata=body.metadata,
        dataset_id=body.dataset_id,
    )


@router.get(
    "/datasets/{dataset_id:path}",
    response_model=Dataset,
    summary="Get a dataset by its ID",
    description="Get a dataset by its ID",
    deprecated=True,
)
@router_v1beta.get(
    "/datasets/{dataset_id:path}",
    response_model=Dataset,
    summary="Get a dataset by its ID",
    description="Get a dataset by its ID",
)
async def get_dataset(
    dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to get")],
    svc: DatasetsService = Depends(get_datasets_service),
) -> Dataset:
    """Get a dataset by its ID."""
    return await svc.get_dataset(dataset_id=dataset_id)


@router.get(
    "/datasets",
    response_model=ListDatasetsResponse,
    summary="List all datasets",
    description="List all datasets",
    deprecated=True,
)
@router_v1beta.get(
    "/datasets",
    response_model=ListDatasetsResponse,
    summary="List all datasets",
    description="List all datasets",
)
async def list_datasets(svc: DatasetsService = Depends(get_datasets_service)) -> ListDatasetsResponse:
    """List all datasets."""
    return await svc.list_datasets()


@router.delete(
    "/datasets/{dataset_id:path}",
    response_model=None,
    status_code=204,
    summary="Unregister a dataset by its ID",
    description="Unregister a dataset by its ID",
    deprecated=True,
)
@router_v1beta.delete(
    "/datasets/{dataset_id:path}",
    response_model=None,
    status_code=204,
    summary="Unregister a dataset by its ID",
    description="Unregister a dataset by its ID",
)
async def unregister_dataset(
    dataset_id: Annotated[str, FastAPIPath(..., description="The ID of the dataset to unregister")],
    svc: DatasetsService = Depends(get_datasets_service),
) -> None:
    """Unregister a dataset by its ID."""
    await svc.unregister_dataset(dataset_id=dataset_id)


# For backward compatibility with the router registry system
def create_datasets_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Datasets API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1beta)
    return main_router


# Register the router factory
register_router(Api.datasets, create_datasets_router)
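How a server might wire this up, as a hedged sketch (the stack's actual server wiring lives elsewhere; the builder function and its argument are hypothetical):

from fastapi import FastAPI

from llama_stack.apis.datasets.routes import create_datasets_router
from llama_stack.apis.datatypes import Api


def build_app(datasets_impl) -> FastAPI:
    # datasets_impl: any concrete DatasetsService (hypothetical)
    app = FastAPI()
    app.state.impls = {Api.datasets: datasets_impl}
    app.include_router(create_datasets_router(None))  # impl_getter is unused by this factory
    return app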
@@ -130,23 +130,22 @@ class Api(Enum, metaclass=DynamicApiMeta):

     # built-in API
     inspect = "inspect"
+    synthetic_data_generation = "synthetic_data_generation"


 @json_schema_type
 class Error(BaseModel):
-    """
-    Error response from the API. Roughly follows RFC 7807.
-
-    :param status: HTTP status code
-    :param title: Error title, a short summary of the error which is invariant for an error type
-    :param detail: Error detail, a longer human-readable description of the error
-    :param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
-    """
-
-    status: int
-    title: str
-    detail: str
-    instance: str | None = None
+    """Error response from the API. Roughly follows RFC 7807."""
+
+    status: int = Field(..., description="HTTP status code")
+    title: str = Field(
+        ..., description="Error title, a short summary of the error which is invariant for an error type"
+    )
+    detail: str = Field(..., description="Error detail, a longer human-readable description of the error")
+    instance: str | None = Field(
+        None,
+        description="(Optional) A URL which can be used to retrieve more information about the specific occurrence of the error",
+    )


 class ExternalApiSpec(BaseModel):
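The net effect of the hunk above is that per-field documentation moves from the class docstring into Field(description=...), so it lands in the generated schema; construction is unchanged (values below are illustrative):

from llama_stack.apis.datatypes import Error

err = Error(status=404, title="Not Found", detail="The requested resource is not registered")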
@@ -4,4 +4,28 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .eval import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .eval_service import EvalService
+from .models import (
+    AgentCandidate,
+    BenchmarkConfig,
+    EvalCandidate,
+    EvaluateResponse,
+    EvaluateRowsRequest,
+    ModelCandidate,
+)
+
+# Backward compatibility - export Eval as alias for EvalService
+Eval = EvalService
+
+__all__ = [
+    "Eval",
+    "EvalService",
+    "ModelCandidate",
+    "AgentCandidate",
+    "EvalCandidate",
+    "BenchmarkConfig",
+    "EvaluateResponse",
+    "EvaluateRowsRequest",
+]
@@ -1,150 +0,0 @@ (file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class ModelCandidate(BaseModel):
    """A model candidate for evaluation.

    :param model: The model ID to evaluate.
    :param sampling_params: The sampling parameters for the model.
    :param system_message: (Optional) The system message providing instructions or context to the model.
    """

    type: Literal["model"] = "model"
    model: str
    sampling_params: SamplingParams
    system_message: SystemMessage | None = None


@json_schema_type
class AgentCandidate(BaseModel):
    """An agent candidate for evaluation.

    :param config: The configuration for the agent candidate.
    """

    type: Literal["agent"] = "agent"
    config: AgentConfig


EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")


@json_schema_type
class BenchmarkConfig(BaseModel):
    """A benchmark configuration for evaluation.

    :param eval_candidate: The candidate to evaluate.
    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
    """

    eval_candidate: EvalCandidate
    scoring_params: dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run",
        default_factory=dict,
    )
    num_examples: int | None = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )
    # we could optionally add any specific dataset config here


@json_schema_type
class EvaluateResponse(BaseModel):
    """The response from an evaluation.

    :param generations: The generations from the evaluation.
    :param scores: The scores from the evaluation.
    """

    generations: list[dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: dict[str, ScoringResult]


class Eval(Protocol):
    """Evaluations

    Llama Stack Evaluation API for running evaluations on model and agent candidates."""

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Run an evaluation on a benchmark.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param benchmark_config: The configuration for the benchmark.
        :returns: The job that was created to run the evaluation.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        """Evaluate a list of rows on a benchmark.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param input_rows: The rows to evaluate.
        :param scoring_functions: The scoring functions to use for the evaluation.
        :param benchmark_config: The configuration for the benchmark.
        :returns: EvaluateResponse object containing generations and scores.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
        """Get the status of a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to get the status of.
        :returns: The status of the evaluation job.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
        """Cancel a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to cancel.
        """
        ...

    @webmethod(
        route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
    )
    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
        """Get the result of a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to get the result of.
        :returns: The result of the job.
        """
        ...
src/llama_stack/apis/eval/eval_service.py (new file, +50)
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.apis.common.job_types import Job
from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import BenchmarkConfig, EvaluateResponse


@runtime_checkable
@trace_protocol
class EvalService(Protocol):
    """Evaluations

    Llama Stack Evaluation API for running evaluations on model and agent candidates."""

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Run an evaluation on a benchmark."""
        ...

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        """Evaluate a list of rows on a benchmark."""
        ...

    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
        """Get the status of a job."""
        ...

    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
        """Cancel a job."""
        ...

    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
        """Get the result of a job."""
        ...
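Because EvalService is runtime_checkable, structural conformance can be verified at wiring time; a small sketch (the stub class is hypothetical):

from llama_stack.apis.eval.eval_service import EvalService


class _StubEval:
    async def run_eval(self, benchmark_id, benchmark_config): ...
    async def evaluate_rows(self, benchmark_id, input_rows, scoring_functions, benchmark_config): ...
    async def job_status(self, benchmark_id, job_id): ...
    async def job_cancel(self, benchmark_id, job_id): ...
    async def job_result(self, benchmark_id, job_id): ...


assert isinstance(_StubEval(), EvalService)  # runtime_checkable only checks method presence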
src/llama_stack/apis/eval/models.py (new file, +73)
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring.models import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema


@json_schema_type
class ModelCandidate(BaseModel):
    """A model candidate for evaluation."""

    type: Literal["model"] = Field(default="model", description="The type of candidate.")
    model: str = Field(..., description="The model ID to evaluate.")
    sampling_params: SamplingParams = Field(..., description="The sampling parameters for the model.")
    system_message: SystemMessage | None = Field(
        default=None, description="The system message providing instructions or context to the model."
    )


@json_schema_type
class AgentCandidate(BaseModel):
    """An agent candidate for evaluation."""

    type: Literal["agent"] = Field(default="agent", description="The type of candidate.")
    config: AgentConfig = Field(..., description="The configuration for the agent candidate.")


EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")


@json_schema_type
class BenchmarkConfig(BaseModel):
    """A benchmark configuration for evaluation."""

    eval_candidate: EvalCandidate = Field(..., description="The candidate to evaluate.")
    scoring_params: dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run.",
        default_factory=dict,
    )
    num_examples: int | None = Field(
        description="The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated.",
        default=None,
    )


@json_schema_type
class EvaluateResponse(BaseModel):
    """The response from an evaluation."""

    generations: list[dict[str, Any]] = Field(..., description="The generations from the evaluation.")
    scores: dict[str, ScoringResult] = Field(
        ..., description="The scores from the evaluation. Each key in the dict is a scoring function name."
    )


@json_schema_type
class EvaluateRowsRequest(BaseModel):
    """Request model for evaluating rows."""

    input_rows: list[dict[str, Any]] = Field(..., description="The rows to evaluate.")
    scoring_functions: list[str] = Field(..., description="The scoring functions to use for the evaluation.")
    benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark.")
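A construction sketch for the new eval models (hedged: the model ID is a placeholder, and SamplingParams is assumed to be default-constructible):

from llama_stack.apis.eval.models import BenchmarkConfig, ModelCandidate
from llama_stack.apis.inference import SamplingParams

config = BenchmarkConfig(
    eval_candidate=ModelCandidate(
        model="my-model-id",  # placeholder model identifier
        sampling_params=SamplingParams(),  # assumed default-constructible
    ),
    num_examples=10,  # evaluate only the first 10 rows while testing
)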
src/llama_stack/apis/eval/routes.py (new file, +170)
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .eval_service import EvalService
from .models import BenchmarkConfig, EvaluateResponse, EvaluateRowsRequest


def get_eval_service(request: Request) -> EvalService:
    """Dependency to get the eval service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.eval not in impls:
        raise ValueError("Eval API implementation not found")
    return impls[Api.eval]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Eval"],
    responses=standard_responses,
)

router_v1alpha = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
    tags=["Eval"],
    responses=standard_responses,
)


@router.post(
    "/eval/benchmarks/{benchmark_id}/jobs",
    response_model=Job,
    summary="Run an evaluation on a benchmark",
    description="Run an evaluation on a benchmark",
    deprecated=True,
)
@router_v1alpha.post(
    "/eval/benchmarks/{benchmark_id}/jobs",
    response_model=Job,
    summary="Run an evaluation on a benchmark",
    description="Run an evaluation on a benchmark",
)
async def run_eval(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
    benchmark_config: BenchmarkConfig = Body(...),
    svc: EvalService = Depends(get_eval_service),
) -> Job:
    """Run an evaluation on a benchmark."""
    return await svc.run_eval(benchmark_id=benchmark_id, benchmark_config=benchmark_config)


@router.post(
    "/eval/benchmarks/{benchmark_id}/evaluations",
    response_model=EvaluateResponse,
    summary="Evaluate a list of rows on a benchmark",
    description="Evaluate a list of rows on a benchmark",
    deprecated=True,
)
@router_v1alpha.post(
    "/eval/benchmarks/{benchmark_id}/evaluations",
    response_model=EvaluateResponse,
    summary="Evaluate a list of rows on a benchmark",
    description="Evaluate a list of rows on a benchmark",
)
async def evaluate_rows(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
    body: EvaluateRowsRequest = Body(...),
    svc: EvalService = Depends(get_eval_service),
) -> EvaluateResponse:
    """Evaluate a list of rows on a benchmark."""
    return await svc.evaluate_rows(
        benchmark_id=benchmark_id,
        input_rows=body.input_rows,
        scoring_functions=body.scoring_functions,
        benchmark_config=body.benchmark_config,
    )


@router.get(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
    response_model=Job,
    summary="Get the status of a job",
    description="Get the status of a job",
    deprecated=True,
)
@router_v1alpha.get(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
    response_model=Job,
    summary="Get the status of a job",
    description="Get the status of a job",
)
async def job_status(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
    job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to get the status of")],
    svc: EvalService = Depends(get_eval_service),
) -> Job:
    """Get the status of a job."""
    return await svc.job_status(benchmark_id=benchmark_id, job_id=job_id)


@router.delete(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
    response_model=None,
    status_code=204,
    summary="Cancel a job",
    description="Cancel a job",
    deprecated=True,
)
@router_v1alpha.delete(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
    response_model=None,
    status_code=204,
    summary="Cancel a job",
    description="Cancel a job",
)
async def job_cancel(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
    job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to cancel")],
    svc: EvalService = Depends(get_eval_service),
) -> None:
    """Cancel a job."""
    await svc.job_cancel(benchmark_id=benchmark_id, job_id=job_id)


@router.get(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
    response_model=EvaluateResponse,
    summary="Get the result of a job",
    description="Get the result of a job",
    deprecated=True,
)
@router_v1alpha.get(
    "/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
    response_model=EvaluateResponse,
    summary="Get the result of a job",
    description="Get the result of a job",
)
async def job_result(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to run the evaluation on")],
    job_id: Annotated[str, FastAPIPath(..., description="The ID of the job to get the result of")],
    svc: EvalService = Depends(get_eval_service),
) -> EvaluateResponse:
    """Get the result of a job."""
    return await svc.job_result(benchmark_id=benchmark_id, job_id=job_id)


# For backward compatibility with the router registry system
def create_eval_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Eval API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1alpha)
    return main_router


# Register the router factory
register_router(Api.eval, create_eval_router)
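For reference, a hedged client-side sketch against the v1alpha endpoint above (assumes a local server on port 8321 and that LLAMA_STACK_API_V1ALPHA renders as "v1alpha"; the benchmark and model IDs are placeholders):

import httpx

resp = httpx.post(
    "http://localhost:8321/v1alpha/eval/benchmarks/my-benchmark/jobs",
    json={"eval_candidate": {"type": "model", "model": "my-model-id", "sampling_params": {}}},
)
resp.raise_for_status()
print(resp.json())  # a Job payload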
@@ -4,4 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .files import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .files_service import FileService
+from .models import (
+    ExpiresAfter,
+    ListOpenAIFileResponse,
+    OpenAIFileDeleteResponse,
+    OpenAIFileObject,
+    OpenAIFilePurpose,
+)
+
+# Backward compatibility - export Files as alias for FileService
+Files = FileService
+
+__all__ = [
+    "Files",
+    "FileService",
+    "OpenAIFileObject",
+    "OpenAIFilePurpose",
+    "ExpiresAfter",
+    "ListOpenAIFileResponse",
+    "OpenAIFileDeleteResponse",
+]
@@ -1,194 +0,0 @@ (file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable

from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field

from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod


# OpenAI Files API Models
class OpenAIFilePurpose(StrEnum):
    """
    Valid purpose values for OpenAI Files API.
    """

    ASSISTANTS = "assistants"
    BATCH = "batch"
    # TODO: Add other purposes as needed


@json_schema_type
class OpenAIFileObject(BaseModel):
    """
    OpenAI File object as defined in the OpenAI Files API.

    :param object: The object type, which is always "file"
    :param id: The file identifier, which can be referenced in the API endpoints
    :param bytes: The size of the file, in bytes
    :param created_at: The Unix timestamp (in seconds) for when the file was created
    :param expires_at: The Unix timestamp (in seconds) for when the file expires
    :param filename: The name of the file
    :param purpose: The intended purpose of the file
    """

    object: Literal["file"] = "file"
    id: str
    bytes: int
    created_at: int
    expires_at: int
    filename: str
    purpose: OpenAIFilePurpose


@json_schema_type
class ExpiresAfter(BaseModel):
    """
    Control expiration of uploaded files.

    Params:
    - anchor, must be "created_at"
    - seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
    """

    MIN: ClassVar[int] = 3600  # 1 hour
    MAX: ClassVar[int] = 2592000  # 30 days

    anchor: Literal["created_at"]
    seconds: int = Field(..., ge=3600, le=2592000)


@json_schema_type
class ListOpenAIFileResponse(BaseModel):
    """
    Response for listing files in OpenAI Files API.

    :param data: List of file objects
    :param has_more: Whether there are more files available beyond this page
    :param first_id: ID of the first file in the list for pagination
    :param last_id: ID of the last file in the list for pagination
    :param object: The object type, which is always "list"
    """

    data: list[OpenAIFileObject]
    has_more: bool
    first_id: str
    last_id: str
    object: Literal["list"] = "list"


@json_schema_type
class OpenAIFileDeleteResponse(BaseModel):
    """
    Response for deleting a file in OpenAI Files API.

    :param id: The file identifier that was deleted
    :param object: The object type, which is always "file"
    :param deleted: Whether the file was successfully deleted
    """

    id: str
    object: Literal["file"] = "file"
    deleted: bool


@runtime_checkable
@trace_protocol
class Files(Protocol):
    """Files

    This API is used to upload documents that can be used with other Llama Stack APIs.
    """

    # OpenAI Files API Endpoints
    @webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
        expires_after: Annotated[ExpiresAfter | None, Form()] = None,
    ) -> OpenAIFileObject:
        """Upload file.

        Upload a file that can be used across various endpoints.

        The file upload should be a multipart form request with:
        - file: The File object (not file name) to be uploaded.
        - purpose: The intended purpose of the uploaded file.
        - expires_after: Optional form values describing expiration for the file.

        :param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
        :param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
        :param expires_after: Optional form values describing expiration for the file.
        :returns: An OpenAIFileObject representing the uploaded file.
        """
        ...

    @webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        """List files.

        Returns a list of files that belong to the user's organization.

        :param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
        :param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
        :param purpose: Only return files with the given purpose.
        :returns: A ListOpenAIFileResponse containing the list of files.
        """
        ...

    @webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def openai_retrieve_file(
        self,
        file_id: str,
    ) -> OpenAIFileObject:
        """Retrieve file.

        Returns information about a specific file.

        :param file_id: The ID of the file to use for this request.
        :returns: An OpenAIFileObject containing file information.
        """
        ...

    @webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_file(
        self,
        file_id: str,
    ) -> OpenAIFileDeleteResponse:
        """Delete file.

        :param file_id: The ID of the file to use for this request.
        :returns: An OpenAIFileDeleteResponse indicating successful deletion.
        """
        ...

    @webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
    async def openai_retrieve_file_content(
        self,
        file_id: str,
    ) -> Response:
        """Retrieve file content.

        Returns the contents of the specified file.

        :param file_id: The ID of the file to use for this request.
        :returns: The raw file content as a binary response.
        """
        ...
src/llama_stack/apis/files/files_service.py (new file, +70)
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Protocol, runtime_checkable

from fastapi import File, Form, Response, UploadFile

from llama_stack.apis.common.responses import Order
from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import (
    ExpiresAfter,
    ListOpenAIFileResponse,
    OpenAIFileDeleteResponse,
    OpenAIFileObject,
    OpenAIFilePurpose,
)


@runtime_checkable
@trace_protocol
class FileService(Protocol):
    """Files

    This API is used to upload documents that can be used with other Llama Stack APIs.
    """

    # OpenAI Files API Endpoints
    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
        expires_after: Annotated[ExpiresAfter | None, Form()] = None,
    ) -> OpenAIFileObject:
        """Upload file."""
        ...

    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        """List files."""
        ...

    async def openai_retrieve_file(
        self,
        file_id: str,
    ) -> OpenAIFileObject:
        """Retrieve file."""
        ...

    async def openai_delete_file(
        self,
        file_id: str,
    ) -> OpenAIFileDeleteResponse:
        """Delete file."""
        ...

    async def openai_retrieve_file_content(
        self,
        file_id: str,
    ) -> Response:
        """Retrieve file content."""
        ...
src/llama_stack/apis/files/models.py (new file, +66)
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import ClassVar, Literal

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class OpenAIFilePurpose(StrEnum):
    """
    Valid purpose values for OpenAI Files API.
    """

    ASSISTANTS = "assistants"
    BATCH = "batch"
    # TODO: Add other purposes as needed


@json_schema_type
class OpenAIFileObject(BaseModel):
    """OpenAI File object as defined in the OpenAI Files API."""

    object: Literal["file"] = Field(default="file", description="The object type, which is always 'file'.")
    id: str = Field(..., description="The file identifier, which can be referenced in the API endpoints.")
    bytes: int = Field(..., description="The size of the file, in bytes.")
    created_at: int = Field(..., description="The Unix timestamp (in seconds) for when the file was created.")
    expires_at: int = Field(..., description="The Unix timestamp (in seconds) for when the file expires.")
    filename: str = Field(..., description="The name of the file.")
    purpose: OpenAIFilePurpose = Field(..., description="The intended purpose of the file.")


@json_schema_type
class ExpiresAfter(BaseModel):
    """Control expiration of uploaded files."""

    MIN: ClassVar[int] = 3600  # 1 hour
    MAX: ClassVar[int] = 2592000  # 30 days

    anchor: Literal["created_at"] = Field(..., description="Anchor must be 'created_at'.")
    seconds: int = Field(..., ge=3600, le=2592000, description="Seconds between 3600 and 2592000 (1 hour to 30 days).")


@json_schema_type
class ListOpenAIFileResponse(BaseModel):
    """Response for listing files in OpenAI Files API."""

    data: list[OpenAIFileObject] = Field(..., description="List of file objects.")
    has_more: bool = Field(..., description="Whether there are more files available beyond this page.")
    first_id: str = Field(..., description="ID of the first file in the list for pagination.")
    last_id: str = Field(..., description="ID of the last file in the list for pagination.")
    object: Literal["list"] = Field(default="list", description="The object type, which is always 'list'.")


@json_schema_type
class OpenAIFileDeleteResponse(BaseModel):
    """Response for deleting a file in OpenAI Files API."""

    id: str = Field(..., description="The file identifier that was deleted.")
    object: Literal["file"] = Field(default="file", description="The object type, which is always 'file'.")
    deleted: bool = Field(..., description="Whether the file was successfully deleted.")
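A small validation sketch for ExpiresAfter, assuming pydantic v2 enforces the ge/le bounds at construction time:

from pydantic import ValidationError

from llama_stack.apis.files.models import ExpiresAfter

exp = ExpiresAfter(anchor="created_at", seconds=86400)  # expire one day after creation

try:
    ExpiresAfter(anchor="created_at", seconds=60)  # below the 3600-second minimum
except ValidationError as e:
    print("rejected:", e.error_count(), "error(s)")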
src/llama_stack/apis/files/routes.py (new file, +135)
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, File, Form, Query, Request, Response, UploadFile
|
||||||
|
from fastapi import Path as FastAPIPath
|
||||||
|
|
||||||
|
from llama_stack.apis.common.responses import Order
|
||||||
|
from llama_stack.apis.datatypes import Api
|
||||||
|
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||||
|
from llama_stack.core.server.router_utils import standard_responses
|
||||||
|
from llama_stack.core.server.routers import APIRouter, register_router
|
||||||
|
|
||||||
|
from .files_service import FileService
|
||||||
|
from .models import (
|
||||||
|
ExpiresAfter,
|
    ListOpenAIFileResponse,
    OpenAIFileDeleteResponse,
    OpenAIFileObject,
    OpenAIFilePurpose,
)


def get_file_service(request: Request) -> FileService:
    """Dependency to get the file service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.files not in impls:
        raise ValueError("Files API implementation not found")
    return impls[Api.files]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Files"],
    responses=standard_responses,
)


@router.post(
    "/files",
    response_model=OpenAIFileObject,
    summary="Upload file.",
    description="Upload a file that can be used across various endpoints.",
)
async def openai_upload_file(
    file: Annotated[UploadFile, File(..., description="The File object to be uploaded.")],
    purpose: Annotated[OpenAIFilePurpose, Form(..., description="The intended purpose of the uploaded file.")],
    expires_after: Annotated[
        ExpiresAfter | None, Form(description="Optional form values describing expiration for the file.")
    ] = None,
    svc: FileService = Depends(get_file_service),
) -> OpenAIFileObject:
    """Upload a file."""
    return await svc.openai_upload_file(file=file, purpose=purpose, expires_after=expires_after)


@router.get(
    "/files",
    response_model=ListOpenAIFileResponse,
    summary="List files.",
    description="Returns a list of files that belong to the user's organization.",
)
async def openai_list_files(
    after: str | None = Query(
        None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
    ),
    limit: int | None = Query(
        10000,
        description="A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.",
    ),
    order: Order | None = Query(
        Order.desc,
        description="Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.",
    ),
    purpose: OpenAIFilePurpose | None = Query(None, description="Only return files with the given purpose."),
    svc: FileService = Depends(get_file_service),
) -> ListOpenAIFileResponse:
    """List files."""
    return await svc.openai_list_files(after=after, limit=limit, order=order, purpose=purpose)


@router.get(
    "/files/{file_id}",
    response_model=OpenAIFileObject,
    summary="Retrieve file.",
    description="Returns information about a specific file.",
)
async def openai_retrieve_file(
    file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
    svc: FileService = Depends(get_file_service),
) -> OpenAIFileObject:
    """Retrieve file information."""
    return await svc.openai_retrieve_file(file_id=file_id)


@router.delete(
    "/files/{file_id}",
    response_model=OpenAIFileDeleteResponse,
    summary="Delete file.",
    description="Delete a file.",
)
async def openai_delete_file(
    file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
    svc: FileService = Depends(get_file_service),
) -> OpenAIFileDeleteResponse:
    """Delete a file."""
    return await svc.openai_delete_file(file_id=file_id)


@router.get(
    "/files/{file_id}/content",
    response_class=Response,
    summary="Retrieve file content.",
    description="Returns the contents of the specified file.",
)
async def openai_retrieve_file_content(
    file_id: Annotated[str, FastAPIPath(..., description="The ID of the file to use for this request.")],
    svc: FileService = Depends(get_file_service),
) -> Response:
    """Retrieve file content."""
    return await svc.openai_retrieve_file_content(file_id=file_id)


# For backward compatibility with the router registry system
def create_files_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Files API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.files, create_files_router)
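For orientation, a minimal client-side sketch of these routes. This is an assumption-laden example, not part of the diff: the base URL, port, API key, file name, and purpose value are placeholders, and the OpenAI Python client is used only because the endpoints mirror the OpenAI Files API.

    from openai import OpenAI

    # Placeholder address; adjust to wherever the stack server listens.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    uploaded = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")
    print(client.files.retrieve(uploaded.id).filename)   # metadata round-trip
    print([f.id for f in client.files.list().data])      # cursor-paginated listing
    client.files.delete(uploaded.id)                     # cleanup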
@@ -4,4 +4,206 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .inference import *
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat

# Import routes to trigger router registration
from . import routes  # noqa: F401
from .inference_service import InferenceService
from .models import (
    Bf16QuantizationConfig,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Fp8QuantizationConfig,
    GrammarResponseFormat,
    GreedySamplingStrategy,
    Int4QuantizationConfig,
    JsonSchemaResponseFormat,
    ListOpenAIChatCompletionResponse,
    LogProbConfig,
    Message,
    ModelStore,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionMessageContent,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionTextOnlyMessageContent,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChatCompletionUsage,
    OpenAIChatCompletionUsageCompletionTokensDetails,
    OpenAIChatCompletionUsagePromptTokensDetails,
    OpenAIChoice,
    OpenAIChoiceDelta,
    OpenAIChoiceLogprobs,
    OpenAIChunkChoice,
    OpenAICompletion,
    OpenAICompletionChoice,
    OpenAICompletionLogprobs,
    OpenAICompletionRequestWithExtraBody,
    OpenAICompletionWithInputMessages,
    OpenAIDeveloperMessageParam,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIFile,
    OpenAIFileFile,
    OpenAIImageURL,
    OpenAIJSONSchema,
    OpenAIMessageParam,
    OpenAIResponseFormatJSONObject,
    OpenAIResponseFormatJSONSchema,
    OpenAIResponseFormatParam,
    OpenAIResponseFormatText,
    OpenAISystemMessageParam,
    OpenAITokenLogProb,
    OpenAIToolMessageParam,
    OpenAITopLogProb,
    OpenAIUserMessageParam,
    QuantizationConfig,
    QuantizationType,
    RerankData,
    RerankResponse,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    SamplingStrategy,
    SystemMessage,
    SystemMessageBehavior,
    TextTruncation,
    TokenLogProbs,
    ToolChoice,
    ToolConfig,
    ToolResponse,
    ToolResponseMessage,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
    UserMessage,
)

# Backward compatibility - export Inference as alias for InferenceService
Inference = InferenceService
InferenceProvider = InferenceService

__all__ = [
    "Inference",
    "InferenceProvider",
    "InferenceService",
    "InterleavedContent",
    "ModelStore",
    "Order",
    # Sampling
    "SamplingParams",
    "SamplingStrategy",
    "GreedySamplingStrategy",
    "TopPSamplingStrategy",
    "TopKSamplingStrategy",
    # Quantization
    "QuantizationConfig",
    "QuantizationType",
    "Bf16QuantizationConfig",
    "Fp8QuantizationConfig",
    "Int4QuantizationConfig",
    # Messages
    "Message",
    "UserMessage",
    "SystemMessage",
    "ToolResponseMessage",
    "CompletionMessage",
    # Tools
    "BuiltinTool",
    "ToolCall",
    "ToolChoice",
    "ToolConfig",
    "ToolDefinition",
    "ToolPromptFormat",
    "ToolResponse",
    # StopReason
    "StopReason",
    # Completion
    "CompletionRequest",
    "CompletionResponse",
    "CompletionResponseStreamChunk",
    # Chat Completion
    "ChatCompletionRequest",
    "ChatCompletionResponse",
    "ChatCompletionResponseStreamChunk",
    "ChatCompletionResponseEvent",
    "ChatCompletionResponseEventType",
    # Embeddings
    "EmbeddingsResponse",
    "EmbeddingTaskType",
    "TextTruncation",
    # Rerank
    "RerankResponse",
    "RerankData",
    # Response Format
    "ResponseFormat",
    "ResponseFormatType",
    "JsonSchemaResponseFormat",
    "GrammarResponseFormat",
    # Log Probs
    "LogProbConfig",
    "TokenLogProbs",
    # System Message Behavior
    "SystemMessageBehavior",
    # OpenAI Models
    "OpenAICompletion",
    "OpenAICompletionRequestWithExtraBody",
    "OpenAICompletionChoice",
    "OpenAICompletionLogprobs",
    "OpenAIChatCompletion",
    "OpenAIChatCompletionRequestWithExtraBody",
    "OpenAIChatCompletionChunk",
    "OpenAIChatCompletionUsage",
    "OpenAIChatCompletionUsageCompletionTokensDetails",
    "OpenAIChatCompletionUsagePromptTokensDetails",
    "OpenAIChoice",
    "OpenAIChoiceDelta",
    "OpenAIChoiceLogprobs",
    "OpenAIChunkChoice",
    "OpenAIMessageParam",
    "OpenAIUserMessageParam",
    "OpenAISystemMessageParam",
    "OpenAIAssistantMessageParam",
    "OpenAIToolMessageParam",
    "OpenAIDeveloperMessageParam",
    "OpenAIChatCompletionContentPartParam",
    "OpenAIChatCompletionContentPartTextParam",
    "OpenAIChatCompletionContentPartImageParam",
    "OpenAIChatCompletionMessageContent",
    "OpenAIChatCompletionTextOnlyMessageContent",
    "OpenAIChatCompletionToolCall",
    "OpenAIChatCompletionToolCallFunction",
    "OpenAIEmbeddingsRequestWithExtraBody",
    "OpenAIEmbeddingsResponse",
    "OpenAIEmbeddingData",
    "OpenAIEmbeddingUsage",
    "OpenAIResponseFormatParam",
    "OpenAIResponseFormatText",
    "OpenAIResponseFormatJSONSchema",
    "OpenAIResponseFormatJSONObject",
    "OpenAIJSONSchema",
    "OpenAIImageURL",
    "OpenAIFile",
    "OpenAIFileFile",
    "OpenAITokenLogProb",
    "OpenAITopLogProb",
    "OpenAICompletionWithInputMessages",
    "ListOpenAIChatCompletionResponse",
]
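A quick sanity check of the compatibility shim above (a sketch, assuming the package imports cleanly): legacy call sites that import `Inference` or `InferenceProvider` now resolve to the same protocol object.

    from llama_stack.apis.inference import Inference, InferenceProvider, InferenceService

    # Plain aliases, not subclasses -- identity holds.
    assert Inference is InferenceService
    assert InferenceProvider is InferenceService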
File diff suppressed because it is too large

93
src/llama_stack/apis/inference/inference_service.py
Normal file
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Annotated, Protocol, runtime_checkable

from fastapi import Body

from llama_stack.apis.common.responses import Order
from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import (
    ListOpenAIChatCompletionResponse,
    ModelStore,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAICompletionWithInputMessages,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    RerankResponse,
)


@runtime_checkable
@trace_protocol
class InferenceService(Protocol):
    """
    This protocol defines the interface that should be implemented by all inference providers.

    Llama Stack Inference API for generating completions, chat completions, and embeddings.

    This API provides the raw interface to the underlying models. Three kinds of models are supported:
    - LLM models: these models generate "raw" and "chat" (conversational) completions.
    - Embedding models: these models generate embeddings to be used for semantic search.
    - Rerank models: these models reorder the documents based on their relevance to a query.
    """

    API_NAMESPACE: str = "Inference"

    model_store: ModelStore | None = None

    async def rerank(
        self,
        model: str,
        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
        max_num_results: int | None = None,
    ) -> RerankResponse:
        """Rerank a list of documents based on their relevance to a query."""
        ...

    async def openai_completion(
        self,
        params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
    ) -> OpenAICompletion:
        """Create completion."""
        ...

    async def openai_chat_completion(
        self,
        params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Create chat completions."""
        ...

    async def openai_embeddings(
        self,
        params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
    ) -> OpenAIEmbeddingsResponse:
        """Create embeddings."""
        ...

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """List chat completions."""
        ...

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        """Get chat completion."""
        ...
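Because the protocol is `@runtime_checkable`, `isinstance` can verify that a provider exposes the expected members (presence only, not signatures). A sketch with a hypothetical provider class, assuming `trace_protocol` preserves runtime checkability:

    from llama_stack.apis.inference import InferenceService

    class NullProvider:  # hypothetical; not part of this diff
        API_NAMESPACE = "Inference"
        model_store = None

        async def rerank(self, model, query, items, max_num_results=None): ...
        async def openai_completion(self, params): ...
        async def openai_chat_completion(self, params): ...
        async def openai_embeddings(self, params): ...
        async def list_chat_completions(self, after=None, limit=20, model=None, order=None): ...
        async def get_chat_completion(self, completion_id): ...

    assert isinstance(NullProvider(), InferenceService)  # structural check only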
818
src/llama_stack/apis/inference/models.py
Normal file
@@ -0,0 +1,818 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict

from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.core.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, register_schema

register_schema(ToolCall)
register_schema(ToolDefinition)


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    """Greedy sampling strategy that selects the highest probability token at each step."""

    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    """Top-p (nucleus) sampling strategy that samples from the smallest set of tokens with cumulative probability >= p."""

    type: Literal["top_p"] = "top_p"
    temperature: float | None = Field(..., gt=0.0)
    top_p: float | None = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    """Top-k sampling strategy that restricts sampling to the k most likely tokens."""

    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = Annotated[
    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
    Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")


@json_schema_type
class SamplingParams(BaseModel):
    """Sampling parameters."""

    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

    max_tokens: int | None = None
    repetition_penalty: float | None = 1.0
    stop: list[str] | None = None


class LogProbConfig(BaseModel):
    """Configuration for log probability generation."""

    top_k: int | None = 0


class QuantizationType(Enum):
    """Type of model quantization to run inference with."""

    bf16 = "bf16"
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    """Configuration for 8-bit floating point quantization."""

    type: Literal["fp8_mixed"] = "fp8_mixed"


@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    """Configuration for BFloat16 precision (typically no quantization)."""

    type: Literal["bf16"] = "bf16"


@json_schema_type
class Int4QuantizationConfig(BaseModel):
    """Configuration for 4-bit integer quantization."""

    type: Literal["int4_mixed"] = "int4_mixed"
    scheme: str | None = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
    Field(discriminator="type"),
]


@json_schema_type
class UserMessage(BaseModel):
    """A message from the user in a chat conversation."""

    role: Literal["user"] = "user"
    content: InterleavedContent
    context: InterleavedContent | None = None


@json_schema_type
class SystemMessage(BaseModel):
    """A system message providing instructions or context to the model."""

    role: Literal["system"] = "system"
    content: InterleavedContent


@json_schema_type
class ToolResponseMessage(BaseModel):
    """A message representing the result of a tool invocation."""

    role: Literal["tool"] = "tool"
    call_id: str
    content: InterleavedContent


@json_schema_type
class CompletionMessage(BaseModel):
    """A message containing the model's (assistant) response in a chat conversation.

    - `StopReason.end_of_turn`: The model finished generating the entire response.
    - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
    - `StopReason.out_of_tokens`: The model ran out of token budget.
    """

    role: Literal["assistant"] = "assistant"
    content: InterleavedContent
    stop_reason: StopReason
    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])


Message = Annotated[
    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
    Field(discriminator="role"),
]
register_schema(Message, name="Message")


@json_schema_type
class ToolResponse(BaseModel):
    """Response from a tool invocation."""

    call_id: str
    tool_name: BuiltinTool | str
    content: InterleavedContent
    metadata: dict[str, Any] | None = None

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


class ToolChoice(Enum):
    """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."""

    auto = "auto"
    required = "required"
    none = "none"


@json_schema_type
class TokenLogProbs(BaseModel):
    """Log probabilities for generated tokens."""

    logprobs_by_token: dict[str, float]


class ChatCompletionResponseEventType(Enum):
    """Types of events that can occur during chat completion."""

    start = "start"
    complete = "complete"
    progress = "progress"


@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """An event during chat completion generation."""

    event_type: ChatCompletionResponseEventType
    delta: ContentDelta
    logprobs: list[TokenLogProbs] | None = None
    stop_reason: StopReason | None = None


class ResponseFormatType(StrEnum):
    """Types of formats for structured (guided) decoding."""

    json_schema = "json_schema"
    grammar = "grammar"


@json_schema_type
class JsonSchemaResponseFormat(BaseModel):
    """Configuration for JSON schema-guided response generation."""

    type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
    json_schema: dict[str, Any]


@json_schema_type
class GrammarResponseFormat(BaseModel):
    """Configuration for grammar-guided response generation."""

    type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
    bnf: dict[str, Any]


ResponseFormat = Annotated[
    JsonSchemaResponseFormat | GrammarResponseFormat,
    Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
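# --- Illustrative sketch (not part of this file): the discriminated unions
# above dispatch on their tag field during validation. Using pydantic v2's
# TypeAdapter directly (an assumption about the validation entry point):
#
#     from pydantic import TypeAdapter
#
#     strategy = TypeAdapter(SamplingStrategy).validate_python(
#         {"type": "top_p", "temperature": 0.7, "top_p": 0.9}
#     )
#     assert isinstance(strategy, TopPSamplingStrategy)
#
#     params = SamplingParams(strategy=strategy, max_tokens=256)
#     assert params.model_dump()["strategy"]["type"] == "top_p"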
# This is an internally used class
class CompletionRequest(BaseModel):
    content: InterleavedContent
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
    response_format: ResponseFormat | None = None
    stream: bool | None = False
    logprobs: LogProbConfig | None = None


@json_schema_type
class CompletionResponse(MetricResponseMixin):
    """Response from a completion request."""

    content: str
    stop_reason: StopReason
    logprobs: list[TokenLogProbs] | None = None


@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
    """A chunk of a streamed completion response."""

    delta: str
    stop_reason: StopReason | None = None
    logprobs: list[TokenLogProbs] | None = None


class SystemMessageBehavior(Enum):
    """Config for how to override the default system prompt.

    https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
    '{{function_definitions}}' to indicate where the function definitions should be inserted.
    """

    append = "append"
    replace = "replace"


@json_schema_type
class ToolConfig(BaseModel):
    """Configuration for tool use.

    - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
    - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
    - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
    - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
    - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
      '{{function_definitions}}' to indicate where the function definitions should be inserted.
    """

    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)

    def model_post_init(self, __context: Any) -> None:
        if isinstance(self.tool_choice, str):
            try:
                self.tool_choice = ToolChoice[self.tool_choice]
            except KeyError:
                pass


# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
    messages: list[Message]
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

    response_format: ResponseFormat | None = None
    stream: bool | None = False
    logprobs: LogProbConfig | None = None


@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
    """A chunk of a streamed chat completion response."""

    event: ChatCompletionResponseEvent


@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
    """Response from a chat completion request."""

    completion_message: CompletionMessage
    logprobs: list[TokenLogProbs] | None = None


@json_schema_type
class EmbeddingsResponse(BaseModel):
    """Response containing generated embeddings."""

    embeddings: list[list[float]]


@json_schema_type
class RerankData(BaseModel):
    """A single rerank result from a reranking response."""

    index: int
    relevance_score: float


@json_schema_type
class RerankResponse(BaseModel):
    """Response from a reranking request."""

    data: list[RerankData]


@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
    """Text content part for OpenAI-compatible chat completion messages."""

    type: Literal["text"] = "text"
    text: str


@json_schema_type
class OpenAIImageURL(BaseModel):
    """Image URL specification for OpenAI-compatible chat completion messages."""

    url: str
    detail: str | None = None


@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
    """Image content part for OpenAI-compatible chat completion messages."""

    type: Literal["image_url"] = "image_url"
    image_url: OpenAIImageURL


@json_schema_type
class OpenAIFileFile(BaseModel):
    file_id: str | None = None
    filename: str | None = None


@json_schema_type
class OpenAIFile(BaseModel):
    type: Literal["file"] = "file"
    file: OpenAIFileFile


OpenAIChatCompletionContentPartParam = Annotated[
    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
    Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")


OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]


@json_schema_type
class OpenAIUserMessageParam(BaseModel):
    """A message from the user in an OpenAI-compatible chat completion request."""

    role: Literal["user"] = "user"
    content: OpenAIChatCompletionMessageContent
    name: str | None = None


@json_schema_type
class OpenAISystemMessageParam(BaseModel):
    """A system message providing instructions or context to the model."""

    role: Literal["system"] = "system"
    content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None


@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
    """Function call details for OpenAI-compatible tool calls."""

    name: str | None = None
    arguments: str | None = None


@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
    """Tool call specification for OpenAI-compatible chat completion responses."""

    index: int | None = None
    id: str | None = None
    type: Literal["function"] = "function"
    function: OpenAIChatCompletionToolCallFunction | None = None


@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
    """A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."""

    role: Literal["assistant"] = "assistant"
    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
    name: str | None = None
    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


@json_schema_type
class OpenAIToolMessageParam(BaseModel):
    """A message representing the result of a tool invocation in an OpenAI-compatible chat completion request."""

    role: Literal["tool"] = "tool"
    tool_call_id: str
    content: OpenAIChatCompletionTextOnlyMessageContent


@json_schema_type
class OpenAIDeveloperMessageParam(BaseModel):
    """A message from the developer in an OpenAI-compatible chat completion request."""

    role: Literal["developer"] = "developer"
    content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None


OpenAIMessageParam = Annotated[
    OpenAIUserMessageParam
    | OpenAISystemMessageParam
    | OpenAIAssistantMessageParam
    | OpenAIToolMessageParam
    | OpenAIDeveloperMessageParam,
    Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")


@json_schema_type
class OpenAIResponseFormatText(BaseModel):
    """Text response format for OpenAI-compatible chat completion requests."""

    type: Literal["text"] = "text"


@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
    """JSON schema specification for OpenAI-compatible structured response format."""

    name: str
    description: str | None
    strict: bool | None

    # Pydantic BaseModel cannot be used with a schema param, since it already
    # has one. And, we don't want to alias here because then have to handle
    # that alias when converting to OpenAI params. So, to support schema,
    # we use a TypedDict.
    schema: dict[str, Any] | None


@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
    """JSON schema response format for OpenAI-compatible chat completion requests."""

    type: Literal["json_schema"] = "json_schema"
    json_schema: OpenAIJSONSchema


@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
    """JSON object response format for OpenAI-compatible chat completion requests."""

    type: Literal["json_object"] = "json_object"


OpenAIResponseFormatParam = Annotated[
    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")


@json_schema_type
class OpenAITopLogProb(BaseModel):
    """The top log probability for a token from an OpenAI-compatible chat completion response.

    :token: The token
    :bytes: (Optional) The bytes for the token
    :logprob: The log probability of the token
    """

    token: str
    bytes: list[int] | None = None
    logprob: float


@json_schema_type
class OpenAITokenLogProb(BaseModel):
    """The log probability for a token from an OpenAI-compatible chat completion response.

    :token: The token
    :bytes: (Optional) The bytes for the token
    :logprob: The log probability of the token
    :top_logprobs: The top log probabilities for the token
    """

    token: str
    bytes: list[int] | None = None
    logprob: float
    top_logprobs: list[OpenAITopLogProb]


@json_schema_type
class OpenAIChoiceLogprobs(BaseModel):
    """The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."""

    content: list[OpenAITokenLogProb] | None = None
    refusal: list[OpenAITokenLogProb] | None = None


@json_schema_type
class OpenAIChoiceDelta(BaseModel):
    """A delta from an OpenAI-compatible chat completion streaming response."""

    content: str | None = None
    refusal: str | None = None
    role: str | None = None
    tool_calls: list[OpenAIChatCompletionToolCall] | None = None
    reasoning_content: str | None = None


@json_schema_type
class OpenAIChunkChoice(BaseModel):
    """A chunk choice from an OpenAI-compatible chat completion streaming response."""

    delta: OpenAIChoiceDelta
    finish_reason: str
    index: int
    logprobs: OpenAIChoiceLogprobs | None = None


@json_schema_type
class OpenAIChoice(BaseModel):
    """A choice from an OpenAI-compatible chat completion response."""

    message: OpenAIMessageParam
    finish_reason: str
    index: int
    logprobs: OpenAIChoiceLogprobs | None = None


class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
    """Token details for output tokens in OpenAI chat completion usage."""

    reasoning_tokens: int | None = None


class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
    """Token details for prompt tokens in OpenAI chat completion usage."""

    cached_tokens: int | None = None


@json_schema_type
class OpenAIChatCompletionUsage(BaseModel):
    """Usage information for OpenAI chat completion."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
    completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None


@json_schema_type
class OpenAIChatCompletion(BaseModel):
    """Response from an OpenAI-compatible chat completion request."""

    id: str
    choices: list[OpenAIChoice]
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    usage: OpenAIChatCompletionUsage | None = None


@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
    """Chunk from a streaming response to an OpenAI-compatible chat completion request."""

    id: str
    choices: list[OpenAIChunkChoice]
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    usage: OpenAIChatCompletionUsage | None = None


@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
    """The log probabilities for the tokens in the message from an OpenAI-compatible completion response.

    :text_offset: (Optional) The offset of the token in the text
    :token_logprobs: (Optional) The log probabilities for the tokens
    :tokens: (Optional) The tokens
    :top_logprobs: (Optional) The top log probabilities for the tokens
    """

    text_offset: list[int] | None = None
    token_logprobs: list[float] | None = None
    tokens: list[str] | None = None
    top_logprobs: list[dict[str, float]] | None = None


@json_schema_type
class OpenAICompletionChoice(BaseModel):
    """A choice from an OpenAI-compatible completion response.

    :finish_reason: The reason the model stopped generating
    :text: The text of the choice
    :index: The index of the choice
    :logprobs: (Optional) The log probabilities for the tokens in the choice
    """

    finish_reason: str
    text: str
    index: int
    logprobs: OpenAIChoiceLogprobs | None = None


@json_schema_type
class OpenAICompletion(BaseModel):
    """Response from an OpenAI-compatible completion request.

    :id: The ID of the completion
    :choices: List of choices
    :created: The Unix timestamp in seconds when the completion was created
    :model: The model that was used to generate the completion
    :object: The object type, which will be "text_completion"
    """

    id: str
    choices: list[OpenAICompletionChoice]
    created: int
    model: str
    object: Literal["text_completion"] = "text_completion"


@json_schema_type
class OpenAIEmbeddingData(BaseModel):
    """A single embedding data object from an OpenAI-compatible embeddings response."""

    object: Literal["embedding"] = "embedding"
    # TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
    embedding: list[float] | str
    index: int


@json_schema_type
class OpenAIEmbeddingUsage(BaseModel):
    """Usage information for an OpenAI-compatible embeddings response."""

    prompt_tokens: int
    total_tokens: int


@json_schema_type
class OpenAIEmbeddingsResponse(BaseModel):
    """Response from an OpenAI-compatible embeddings request."""

    object: Literal["list"] = "list"
    data: list[OpenAIEmbeddingData]
    model: str
    usage: OpenAIEmbeddingUsage


class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...


class TextTruncation(Enum):
    """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left."""

    none = "none"
    start = "start"
    end = "end"


class EmbeddingTaskType(Enum):
    """How is the embedding being used? This is only supported by asymmetric embedding models."""

    query = "query"
    document = "document"


class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
    input_messages: list[OpenAIMessageParam]


@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
    """Response from listing OpenAI-compatible chat completions."""

    data: list[OpenAICompletionWithInputMessages]
    has_more: bool
    first_id: str
    last_id: str
    object: Literal["list"] = "list"


# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
    """Request parameters for OpenAI-compatible completion endpoint."""

    # Standard OpenAI completion parameters
    model: str
    prompt: str | list[str] | list[int] | list[list[int]]
    best_of: int | None = None
    echo: bool | None = None
    frequency_penalty: float | None = None
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = Field(None, ge=0, le=5)
    max_tokens: int | None = None
    n: int | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    stream: bool | None = None
    stream_options: dict[str, Any] | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None
    suffix: str | None = None


# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
    """Request parameters for OpenAI-compatible chat completion endpoint."""

    # Standard OpenAI chat completion parameters
    model: str
    messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
    frequency_penalty: float | None = None
    function_call: str | dict[str, Any] | None = None
    functions: list[dict[str, Any]] | None = None
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = None
    max_completion_tokens: int | None = None
    max_tokens: int | None = None
    n: int | None = None
    parallel_tool_calls: bool | None = None
    presence_penalty: float | None = None
    response_format: OpenAIResponseFormatParam | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    stream: bool | None = None
    stream_options: dict[str, Any] | None = None
    temperature: float | None = None
    tool_choice: str | dict[str, Any] | None = None
    tools: list[dict[str, Any]] | None = None
    top_logprobs: int | None = None
    top_p: float | None = None
    user: str | None = None


# extra_body can be accessed via .model_extra
@json_schema_type
class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
    """Request parameters for OpenAI-compatible embeddings endpoint."""

    model: str
    input: str | list[str]
    encoding_format: str | None = "float"
    dimensions: int | None = None
    user: str | None = None
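One behavior worth calling out: the three `...WithExtraBody` request models are declared with `extra="allow"`, so unknown provider-specific keys survive validation and surface through `.model_extra` instead of raising. A sketch (the model id and the `guided_choice` key are made-up examples):

    req = OpenAIChatCompletionRequestWithExtraBody(
        model="my-model",  # placeholder id
        messages=[OpenAIUserMessageParam(content="Say yes or no.")],
        guided_choice=["yes", "no"],  # not a declared field
    )
    # Undeclared keys land in model_extra rather than failing validation.
    assert req.model_extra == {"guided_choice": ["yes", "no"]}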
183
src/llama_stack/apis/inference/routes.py
Normal file
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Annotated

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from llama_stack.apis.common.responses import Order
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .inference_service import InferenceService
from .models import (
    ListOpenAIChatCompletionResponse,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAICompletionWithInputMessages,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    RerankResponse,
)


def get_inference_service(request: Request) -> InferenceService:
    """Dependency to get the inference service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.inference not in impls:
        raise ValueError("Inference API implementation not found")
    return impls[Api.inference]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Inference"],
    responses=standard_responses,
)

router_v1alpha = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
    tags=["Inference"],
    responses=standard_responses,
)


@router_v1alpha.post(
    "/inference/rerank",
    response_model=RerankResponse,
    summary="Rerank a list of documents.",
    description="Rerank a list of documents based on their relevance to a query.",
)
async def rerank(
    model: str = Body(..., description="The identifier of the reranking model to use."),
    query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam = Body(
        ..., description="The search query to rank items against."
    ),
    items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam] = Body(
        ..., description="List of items to rerank."
    ),
    max_num_results: int | None = Body(None, description="Maximum number of results to return. Default: returns all."),
    svc: InferenceService = Depends(get_inference_service),
) -> RerankResponse:
    """Rerank a list of documents based on their relevance to a query."""
    return await svc.rerank(model=model, query=query, items=items, max_num_results=max_num_results)


@router.post(
    "/completions",
    response_model=OpenAICompletion,
    summary="Create completion.",
    description="Create completion.",
)
async def openai_completion(
    params: OpenAICompletionRequestWithExtraBody = Body(...),
    svc: InferenceService = Depends(get_inference_service),
) -> OpenAICompletion:
    """Create completion."""
    return await svc.openai_completion(params=params)


@router.post(
    "/chat/completions",
    summary="Create chat completions.",
    description="Create chat completions.",
)
async def openai_chat_completion(
    params: OpenAIChatCompletionRequestWithExtraBody = Body(...),
    svc: InferenceService = Depends(get_inference_service),
):
    """Create chat completions."""
    response = await svc.openai_chat_completion(params=params)

    # Check if response is an async generator/iterator (streaming response)
    # Check for __aiter__ method which all async iterators have
    if hasattr(response, "__aiter__"):
        # Convert async generator to SSE stream
        async def sse_stream():
            try:
                async for chunk in response:
                    if isinstance(chunk, BaseModel):
                        data = chunk.model_dump_json()
                    else:
                        data = json.dumps(chunk)
                    yield f"data: {data}\n\n"
            except Exception as e:
                # Send error as SSE event
                error_data = json.dumps({"error": {"message": str(e)}})
                yield f"data: {error_data}\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream")

    return response


@router.post(
    "/embeddings",
    response_model=OpenAIEmbeddingsResponse,
    summary="Create embeddings.",
    description="Create embeddings.",
)
async def openai_embeddings(
    params: OpenAIEmbeddingsRequestWithExtraBody = Body(...),
    svc: InferenceService = Depends(get_inference_service),
) -> OpenAIEmbeddingsResponse:
    """Create embeddings."""
    return await svc.openai_embeddings(params=params)


@router.get(
    "/chat/completions",
    response_model=ListOpenAIChatCompletionResponse,
    summary="List chat completions.",
    description="List chat completions.",
)
async def list_chat_completions(
    after: str | None = Query(None, description="The ID of the last chat completion to return."),
    limit: int | None = Query(20, description="The maximum number of chat completions to return."),
    model: str | None = Query(None, description="The model to filter by."),
    order: Order | None = Query(
        Order.desc, description="The order to sort the chat completions by: 'asc' or 'desc'. Defaults to 'desc'."
    ),
    svc: InferenceService = Depends(get_inference_service),
) -> ListOpenAIChatCompletionResponse:
    """List chat completions."""
    return await svc.list_chat_completions(after=after, limit=limit, model=model, order=order)


@router.get(
    "/chat/completions/{completion_id}",
    response_model=OpenAICompletionWithInputMessages,
    summary="Get chat completion.",
    description="Get chat completion.",
)
async def get_chat_completion(
    completion_id: Annotated[str, FastAPIPath(..., description="ID of the chat completion.")],
    svc: InferenceService = Depends(get_inference_service),
) -> OpenAICompletionWithInputMessages:
    """Get chat completion."""
    return await svc.get_chat_completion(completion_id=completion_id)


# For backward compatibility with the router registry system
def create_inference_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Inference API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1alpha)
    return main_router


# Register the router factory
register_router(Api.inference, create_inference_router)
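The streaming branch in `openai_chat_completion` emits plain `data: ...` SSE frames, so any SSE-aware client can consume them. A rough sketch with `httpx` (URL and model id are placeholders):

    import json

    import httpx

    payload = {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    with httpx.Client(base_url="http://localhost:8321/v1") as client:
        with client.stream("POST", "/chat/completions", json=payload) as resp:
            for line in resp.iter_lines():
                if line.startswith("data: "):
                    chunk = json.loads(line[len("data: "):])
                    # Each frame mirrors OpenAIChatCompletionChunk.
                    print(chunk["choices"][0]["delta"].get("content") or "", end="")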
@@ -4,4 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .inspect import *
# Import routes to trigger router registration
from . import routes  # noqa: F401
from .inspect_service import InspectService
from .models import HealthInfo, ListRoutesResponse, RouteInfo, VersionInfo

# Backward compatibility - export Inspect as alias for InspectService
Inspect = InspectService

__all__ = ["Inspect", "InspectService", "ListRoutesResponse", "RouteInfo", "HealthInfo", "VersionInfo"]
@@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.apis.version import (
    LLAMA_STACK_API_V1,
)
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod

# Valid values for the route filter parameter.
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]


@json_schema_type
class RouteInfo(BaseModel):
    """Information about an API route including its path, method, and implementing providers.

    :param route: The API endpoint path
    :param method: HTTP method for the route
    :param provider_types: List of provider types that implement this route
    """

    route: str
    method: str
    provider_types: list[str]


@json_schema_type
class HealthInfo(BaseModel):
    """Health status information for the service.

    :param status: Current health status of the service
    """

    status: HealthStatus


@json_schema_type
class VersionInfo(BaseModel):
    """Version information for the service.

    :param version: Version number of the service
    """

    version: str


class ListRoutesResponse(BaseModel):
    """Response containing a list of all available API routes.

    :param data: List of available route information objects
    """

    data: list[RouteInfo]


@runtime_checkable
class Inspect(Protocol):
    """Inspect

    APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
    """

    @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
    async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
        """List routes.

        List all available API routes with their methods and implementing providers.

        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
        :returns: Response containing information about all available routes.
        """
        ...

    @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
    async def health(self) -> HealthInfo:
        """Get health status.

        Get the current health status of the service.

        :returns: Health information indicating if the service is operational.
        """
        ...

    @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
    async def version(self) -> VersionInfo:
        """Get version.

        Get the version of the service.

        :returns: Version information containing the service version number.
        """
        ...
src/llama_stack/apis/inspect/inspect_service.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from .models import HealthInfo, ListRoutesResponse, VersionInfo


@runtime_checkable
class InspectService(Protocol):
    """Inspect

    APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
    """

    async def list_routes(self) -> ListRoutesResponse:
        """List routes."""
        ...

    async def health(self) -> HealthInfo:
        """Get health status."""
        ...

    async def version(self) -> VersionInfo:
        """Get version."""
        ...
src/llama_stack/apis/inspect/models.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class RouteInfo(BaseModel):
    """Information about an API route including its path, method, and implementing providers."""

    route: str = Field(..., description="The API endpoint path")
    method: str = Field(..., description="HTTP method for the route")
    provider_types: list[str] = Field(..., description="List of provider types that implement this route")


@json_schema_type
class HealthInfo(BaseModel):
    """Health status information for the service."""

    status: HealthStatus = Field(..., description="Current health status of the service")


@json_schema_type
class VersionInfo(BaseModel):
    """Version information for the service."""

    version: str = Field(..., description="Version number of the service")


class ListRoutesResponse(BaseModel):
    """Response containing a list of all available API routes."""

    data: list[RouteInfo] = Field(..., description="List of available route information objects")
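Aside (editor's sketch, not part of this commit): one effect of moving the docstring :param fields into Field(description=...) is that the descriptions land directly in the pydantic-generated JSON schema, which is what the regenerated OpenAPI spec files in this commit are built from. A quick check with a trimmed, locally redefined RouteInfo:

# Illustrative only; the real RouteInfo lives in the module above.
from pydantic import BaseModel, Field


class RouteInfo(BaseModel):
    route: str = Field(..., description="The API endpoint path")
    method: str = Field(..., description="HTTP method for the route")


schema = RouteInfo.model_json_schema()
print(schema["properties"]["route"]["description"])  # The API endpoint path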
src/llama_stack/apis/inspect/routes.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from fastapi import Depends, Request

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .inspect_service import InspectService
from .models import HealthInfo, ListRoutesResponse, VersionInfo


def get_inspect_service(request: Request) -> InspectService:
    """Dependency to get the inspect service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.inspect not in impls:
        raise ValueError("Inspect API implementation not found")
    return impls[Api.inspect]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Inspect"],
    responses=standard_responses,
)


@router.get(
    "/inspect/routes",
    response_model=ListRoutesResponse,
    summary="List routes.",
    description="List all available API routes with their methods and implementing providers.",
)
async def list_routes(svc: InspectService = Depends(get_inspect_service)) -> ListRoutesResponse:
    """List all available API routes."""
    return await svc.list_routes()


@router.get(
    "/health",
    response_model=HealthInfo,
    summary="Get health status.",
    description="Get the current health status of the service.",
)
async def health(svc: InspectService = Depends(get_inspect_service)) -> HealthInfo:
    """Get the current health status of the service."""
    return await svc.health()


@router.get(
    "/version",
    response_model=VersionInfo,
    summary="Get version.",
    description="Get the version of the service.",
)
async def version(svc: InspectService = Depends(get_inspect_service)) -> VersionInfo:
    """Get the version of the service."""
    return await svc.version()


# For backward compatibility with the router registry system
def create_inspect_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Inspect API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.inspect, create_inspect_router)
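Aside (editor's sketch, not part of this commit): the get_inspect_service dependency above resolves the implementation from request.app.state.impls at request time. Below is the same pattern reduced to plain FastAPI, with a stub service so it runs standalone; every name in it is illustrative, not taken from this diff.

# Self-contained sketch of the app-state dependency pattern; run with
# `python sketch.py` after `pip install fastapi httpx`.
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import BaseModel


class HealthInfo(BaseModel):
    status: str


class StubInspect:
    async def health(self) -> HealthInfo:
        return HealthInfo(status="OK")


def get_inspect(request: Request) -> StubInspect:
    # Server startup is assumed to have stashed implementations on app.state.
    impls = getattr(request.app.state, "impls", {})
    if "inspect" not in impls:
        raise ValueError("Inspect API implementation not found")
    return impls["inspect"]


router = APIRouter(prefix="/v1")


@router.get("/health", response_model=HealthInfo)
async def health(svc: StubInspect = Depends(get_inspect)) -> HealthInfo:
    return await svc.health()


app = FastAPI()
app.state.impls = {"inspect": StubInspect()}  # done once at startup
app.include_router(router)

print(TestClient(app).get("/v1/health").json())  # {'status': 'OK'}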
src/llama_stack/apis/models/__init__.py
@@ -4,4 +4,30 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .models import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .model_schemas import (
+    ListModelsResponse,
+    Model,
+    ModelInput,
+    ModelType,
+    OpenAIListModelsResponse,
+    OpenAIModel,
+    RegisterModelRequest,
+)
+from .models_service import ModelService
+
+# Backward compatibility - export Models as alias for ModelService
+Models = ModelService
+
+__all__ = [
+    "Models",
+    "ModelService",
+    "Model",
+    "ModelInput",
+    "ModelType",
+    "ListModelsResponse",
+    "RegisterModelRequest",
+    "OpenAIModel",
+    "OpenAIListModelsResponse",
+]
src/llama_stack/apis/models/model_schemas.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, field_validator

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type


class CommonModelFields(BaseModel):
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this model.",
    )


@json_schema_type
class ModelType(StrEnum):
    """Enumeration of supported model types in Llama Stack."""

    llm = "llm"
    embedding = "embedding"
    rerank = "rerank"


@json_schema_type
class Model(CommonModelFields, Resource):
    """A model resource representing an AI model registered in Llama Stack."""

    type: Literal[ResourceType.model] = Field(
        default=ResourceType.model, description="The resource type, always 'model' for model resources."
    )
    model_type: ModelType = Field(default=ModelType.llm, description="The type of model (LLM or embedding model).")

    @property
    def model_id(self) -> str:
        return self.identifier

    @property
    def provider_model_id(self) -> str:
        assert self.provider_resource_id is not None, "Provider resource ID must be set"
        return self.provider_resource_id

    model_config = ConfigDict(protected_namespaces=())

    @field_validator("provider_resource_id")
    @classmethod
    def validate_provider_resource_id(cls, v):
        if v is None:
            raise ValueError("provider_resource_id cannot be None")
        return v


class ModelInput(CommonModelFields):
    model_id: str
    provider_id: str | None = None
    provider_model_id: str | None = None
    model_type: ModelType | None = ModelType.llm
    model_config = ConfigDict(protected_namespaces=())


class ListModelsResponse(BaseModel):
    """Response model for listing models."""

    data: list[Model] = Field(description="List of model resources.")


@json_schema_type
class RegisterModelRequest(BaseModel):
    """Request model for registering a new model."""

    model_id: str = Field(..., description="The identifier of the model to register.")
    provider_model_id: str | None = Field(default=None, description="The identifier of the model in the provider.")
    provider_id: str | None = Field(default=None, description="The identifier of the provider.")
    metadata: dict[str, Any] | None = Field(default=None, description="Any additional metadata for this model.")
    model_type: ModelType | None = Field(default=None, description="The type of model to register.")


@json_schema_type
class OpenAIModel(BaseModel):
    """A model from OpenAI."""

    id: str = Field(..., description="The ID of the model.")
    object: Literal["model"] = Field(default="model", description="The object type, which will be 'model'.")
    created: int = Field(..., description="The Unix timestamp in seconds when the model was created.")
    owned_by: str = Field(..., description="The owner of the model.")


class OpenAIListModelsResponse(BaseModel):
    """Response model for listing OpenAI models."""

    data: list[OpenAIModel] = Field(description="List of OpenAI model objects.")
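Aside (editor's sketch, not part of this commit): the model_config = ConfigDict(protected_namespaces=()) lines above are there because pydantic v2 reserves the model_ prefix for its own methods, and fields such as model_id and model_type trigger a namespace warning unless that protection is lifted. A trimmed illustration:

# Without protected_namespaces=(), pydantic v2 warns about model_* fields.
from pydantic import BaseModel, ConfigDict


class ModelInput(BaseModel):
    model_config = ConfigDict(protected_namespaces=())  # allow model_* field names

    model_id: str
    model_type: str | None = "llm"


print(ModelInput(model_id="my-llm").model_type)  # llm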
src/llama_stack/apis/models/models.py (deleted)
@@ -1,172 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, ConfigDict, Field, field_validator

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod


class CommonModelFields(BaseModel):
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this model",
    )


@json_schema_type
class ModelType(StrEnum):
    """Enumeration of supported model types in Llama Stack.
    :cvar llm: Large language model for text generation and completion
    :cvar embedding: Embedding model for converting text to vector representations
    :cvar rerank: Reranking model for reordering documents based on their relevance to a query
    """

    llm = "llm"
    embedding = "embedding"
    rerank = "rerank"


@json_schema_type
class Model(CommonModelFields, Resource):
    """A model resource representing an AI model registered in Llama Stack.

    :param type: The resource type, always 'model' for model resources
    :param model_type: The type of model (LLM or embedding model)
    :param metadata: Any additional metadata for this model
    :param identifier: Unique identifier for this resource in llama stack
    :param provider_resource_id: Unique identifier for this resource in the provider
    :param provider_id: ID of the provider that owns this resource
    """

    type: Literal[ResourceType.model] = ResourceType.model

    @property
    def model_id(self) -> str:
        return self.identifier

    @property
    def provider_model_id(self) -> str:
        assert self.provider_resource_id is not None, "Provider resource ID must be set"
        return self.provider_resource_id

    model_config = ConfigDict(protected_namespaces=())

    model_type: ModelType = Field(default=ModelType.llm)

    @field_validator("provider_resource_id")
    @classmethod
    def validate_provider_resource_id(cls, v):
        if v is None:
            raise ValueError("provider_resource_id cannot be None")
        return v


class ModelInput(CommonModelFields):
    model_id: str
    provider_id: str | None = None
    provider_model_id: str | None = None
    model_type: ModelType | None = ModelType.llm
    model_config = ConfigDict(protected_namespaces=())


class ListModelsResponse(BaseModel):
    data: list[Model]


@json_schema_type
class OpenAIModel(BaseModel):
    """A model from OpenAI.

    :id: The ID of the model
    :object: The object type, which will be "model"
    :created: The Unix timestamp in seconds when the model was created
    :owned_by: The owner of the model
    :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
    """

    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: str
    custom_metadata: dict[str, Any] | None = None


class OpenAIListModelsResponse(BaseModel):
    data: list[OpenAIModel]


@runtime_checkable
@trace_protocol
class Models(Protocol):
    async def list_models(self) -> ListModelsResponse:
        """List all models.

        :returns: A ListModelsResponse.
        """
        ...

    @webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List models using the OpenAI API.

        :returns: A OpenAIListModelsResponse.
        """
        ...

    @webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_model(
        self,
        model_id: str,
    ) -> Model:
        """Get model.

        Get a model by its identifier.

        :param model_id: The identifier of the model to get.
        :returns: A Model.
        """
        ...

    @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register model.

        Register a model.

        :param model_id: The identifier of the model to register.
        :param provider_model_id: The identifier of the model in the provider.
        :param provider_id: The identifier of the provider.
        :param metadata: Any additional metadata for this model.
        :param model_type: The type of model to register.
        :returns: A Model.
        """
        ...

    @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def unregister_model(
        self,
        model_id: str,
    ) -> None:
        """Unregister model.

        Unregister a model.

        :param model_id: The identifier of the model to unregister.
        """
        ...
src/llama_stack/apis/models/models_models.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, field_validator

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type


class CommonModelFields(BaseModel):
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this model.",
    )


@json_schema_type
class ModelType(StrEnum):
    """Enumeration of supported model types in Llama Stack."""

    llm = "llm"
    embedding = "embedding"
    rerank = "rerank"


@json_schema_type
class Model(CommonModelFields, Resource):
    """A model resource representing an AI model registered in Llama Stack."""

    type: Literal[ResourceType.model] = Field(
        default=ResourceType.model, description="The resource type, always 'model' for model resources."
    )
    model_type: ModelType = Field(default=ModelType.llm, description="The type of model (LLM or embedding model).")

    @property
    def model_id(self) -> str:
        return self.identifier

    @property
    def provider_model_id(self) -> str:
        assert self.provider_resource_id is not None, "Provider resource ID must be set"
        return self.provider_resource_id

    model_config = ConfigDict(protected_namespaces=())

    @field_validator("provider_resource_id")
    @classmethod
    def validate_provider_resource_id(cls, v):
        if v is None:
            raise ValueError("provider_resource_id cannot be None")
        return v


class ModelInput(CommonModelFields):
    model_id: str
    provider_id: str | None = None
    provider_model_id: str | None = None
    model_type: ModelType | None = ModelType.llm
    model_config = ConfigDict(protected_namespaces=())


class ListModelsResponse(BaseModel):
    """Response model for listing models."""

    data: list[Model] = Field(description="List of model resources.")


@json_schema_type
class RegisterModelRequest(BaseModel):
    """Request model for registering a new model."""

    model_id: str = Field(..., description="The identifier of the model to register.")
    provider_model_id: str | None = Field(default=None, description="The identifier of the model in the provider.")
    provider_id: str | None = Field(default=None, description="The identifier of the provider.")
    metadata: dict[str, Any] | None = Field(default=None, description="Any additional metadata for this model.")
    model_type: ModelType | None = Field(default=None, description="The type of model to register.")


@json_schema_type
class OpenAIModel(BaseModel):
    """A model from OpenAI."""

    id: str = Field(..., description="The ID of the model.")
    object: Literal["model"] = Field(default="model", description="The object type, which will be 'model'.")
    created: int = Field(..., description="The Unix timestamp in seconds when the model was created.")
    owned_by: str = Field(..., description="The owner of the model.")


class OpenAIListModelsResponse(BaseModel):
    """Response model for listing OpenAI models."""

    data: list[OpenAIModel] = Field(description="List of OpenAI model objects.")
src/llama_stack/apis/models/models_service.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .model_schemas import (
    ListModelsResponse,
    Model,
    ModelType,
    OpenAIListModelsResponse,
)


@runtime_checkable
@trace_protocol
class ModelService(Protocol):
    async def list_models(self) -> ListModelsResponse:
        """List all models."""
        ...

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List models using the OpenAI API."""
        ...

    async def get_model(
        self,
        model_id: str,
    ) -> Model:
        """Get model."""
        ...

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register model."""
        ...

    async def unregister_model(
        self,
        model_id: str,
    ) -> None:
        """Unregister model."""
        ...
src/llama_stack/apis/models/routes.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .model_schemas import (
    ListModelsResponse,
    Model,
    RegisterModelRequest,
)
from .models_service import ModelService


def get_model_service(request: Request) -> ModelService:
    """Dependency to get the model service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.models not in impls:
        raise ValueError("Models API implementation not found")
    return impls[Api.models]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Models"],
    responses=standard_responses,
)


@router.get(
    "/models",
    response_model=ListModelsResponse,
    summary="List all models.",
    description="List all models registered in Llama Stack.",
)
async def list_models(svc: ModelService = Depends(get_model_service)) -> ListModelsResponse:
    """List all models."""
    return await svc.list_models()


@router.get(
    "/models/{model_id:path}",
    response_model=Model,
    summary="Get model.",
    description="Get a model by its identifier.",
)
async def get_model(
    model_id: Annotated[str, FastAPIPath(..., description="The identifier of the model to get.")],
    svc: ModelService = Depends(get_model_service),
) -> Model:
    """Get model by its identifier."""
    return await svc.get_model(model_id=model_id)


@router.post(
    "/models",
    response_model=Model,
    summary="Register model.",
    description="Register a new model in Llama Stack.",
)
async def register_model(
    body: RegisterModelRequest = Body(...),
    svc: ModelService = Depends(get_model_service),
) -> Model:
    """Register a new model."""
    return await svc.register_model(
        model_id=body.model_id,
        provider_model_id=body.provider_model_id,
        provider_id=body.provider_id,
        metadata=body.metadata,
        model_type=body.model_type,
    )


@router.delete(
    "/models/{model_id:path}",
    response_model=None,
    status_code=204,
    summary="Unregister model.",
    description="Unregister a model from Llama Stack.",
)
async def unregister_model(
    model_id: Annotated[str, FastAPIPath(..., description="The identifier of the model to unregister.")],
    svc: ModelService = Depends(get_model_service),
) -> None:
    """Unregister a model."""
    await svc.unregister_model(model_id=model_id)


# For backward compatibility with the router registry system
def create_models_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Models API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.models, create_models_router)
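Aside (editor's sketch, not part of this commit): with the @router.post("/models") handler above, registration parameters move from webmethod keyword arguments to a JSON body shaped like RegisterModelRequest. An illustrative client call; the base URL, port, and provider_id are assumptions, not values from this diff.

# Requires `pip install httpx` and a running stack server.
import httpx

resp = httpx.post(
    "http://localhost:8321/v1/models",  # assumed default host/port
    json={
        "model_id": "my-llm",
        "provider_id": "ollama",  # optional field of RegisterModelRequest
    },
)
resp.raise_for_status()
print(resp.json()["identifier"])  # identifier of the registered Model resource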
src/llama_stack/apis/post_training/__init__.py
@@ -4,4 +4,61 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .post_training import *
+# Import routes to trigger router registration
+from llama_stack.apis.common.job_types import JobStatus
+from llama_stack.apis.common.training_types import Checkpoint
+
+from . import routes  # noqa: F401
+from .models import (
+    AlgorithmConfig,
+    DataConfig,
+    DatasetFormat,
+    DPOAlignmentConfig,
+    DPOLossType,
+    EfficiencyConfig,
+    ListPostTrainingJobsResponse,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    OptimizerType,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobLogStream,
+    PostTrainingJobStatusResponse,
+    PostTrainingRLHFRequest,
+    PreferenceOptimizeRequest,
+    QATFinetuningConfig,
+    RLHFAlgorithm,
+    SupervisedFineTuneRequest,
+    TrainingConfig,
+)
+from .post_training_service import PostTrainingService
+
+# Backward compatibility - export PostTraining as alias for PostTrainingService
+PostTraining = PostTrainingService
+
+__all__ = [
+    "PostTraining",
+    "PostTrainingService",
+    "Checkpoint",
+    "JobStatus",
+    "OptimizerType",
+    "DatasetFormat",
+    "DataConfig",
+    "OptimizerConfig",
+    "EfficiencyConfig",
+    "TrainingConfig",
+    "LoraFinetuningConfig",
+    "QATFinetuningConfig",
+    "AlgorithmConfig",
+    "PostTrainingJobLogStream",
+    "RLHFAlgorithm",
+    "DPOLossType",
+    "DPOAlignmentConfig",
+    "PostTrainingRLHFRequest",
+    "PostTrainingJob",
+    "PostTrainingJobStatusResponse",
+    "ListPostTrainingJobsResponse",
+    "PostTrainingJobArtifactsResponse",
+    "SupervisedFineTuneRequest",
+    "PreferenceOptimizeRequest",
+]
src/llama_stack/apis/post_training/models.py (new file, 222 lines)
@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema


@json_schema_type
class OptimizerType(Enum):
    """Available optimizer algorithms for training."""

    adam = "adam"
    adamw = "adamw"
    sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
    """Format of the training dataset."""

    instruct = "instruct"
    dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
    """Configuration for training data and data loading."""

    dataset_id: str
    batch_size: int
    shuffle: bool
    data_format: DatasetFormat
    validation_dataset_id: str | None = None
    packed: bool | None = False
    train_on_input: bool | None = False


@json_schema_type
class OptimizerConfig(BaseModel):
    """Configuration parameters for the optimization algorithm."""

    optimizer_type: OptimizerType
    lr: float
    weight_decay: float
    num_warmup_steps: int


@json_schema_type
class EfficiencyConfig(BaseModel):
    """Configuration for memory and compute efficiency optimizations."""

    enable_activation_checkpointing: bool | None = False
    enable_activation_offloading: bool | None = False
    memory_efficient_fsdp_wrap: bool | None = False
    fsdp_cpu_offload: bool | None = False


@json_schema_type
class TrainingConfig(BaseModel):
    """Comprehensive configuration for the training process."""

    n_epochs: int
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: int | None = 1
    data_config: DataConfig | None = None
    optimizer_config: OptimizerConfig | None = None
    efficiency_config: EfficiencyConfig | None = None
    dtype: str | None = "bf16"


@json_schema_type
class LoraFinetuningConfig(BaseModel):
    """Configuration for Low-Rank Adaptation (LoRA) fine-tuning."""

    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: list[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
    rank: int
    alpha: int
    use_dora: bool | None = False
    quantize_base: bool | None = False


@json_schema_type
class QATFinetuningConfig(BaseModel):
    """Configuration for Quantization-Aware Training (QAT) fine-tuning."""

    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")


@json_schema_type
class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job."""

    job_uuid: str
    log_lines: list[str]


@json_schema_type
class RLHFAlgorithm(Enum):
    """Available reinforcement learning from human feedback algorithms."""

    dpo = "dpo"


@json_schema_type
class DPOLossType(Enum):
    sigmoid = "sigmoid"
    hinge = "hinge"
    ipo = "ipo"
    kto_pair = "kto_pair"


@json_schema_type
class DPOAlignmentConfig(BaseModel):
    """Configuration for Direct Preference Optimization (DPO) alignment."""

    beta: float
    loss_type: DPOLossType = DPOLossType.sigmoid


@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
    """Request to finetune a model using reinforcement learning from human feedback."""

    job_uuid: str

    finetuned_model: URL

    dataset_id: str
    validation_dataset_id: str

    algorithm: RLHFAlgorithm
    algorithm_config: DPOAlignmentConfig

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: dict[str, Any]
    logger_config: dict[str, Any]


class PostTrainingJob(BaseModel):
    job_uuid: str = Field(..., description="The UUID of the job")


@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job."""

    job_uuid: str
    status: JobStatus

    scheduled_at: datetime | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    resources_allocated: dict[str, Any] | None = None

    checkpoints: list[Checkpoint] = Field(default_factory=list)


class ListPostTrainingJobsResponse(BaseModel):
    data: list[PostTrainingJob] = Field(..., description="The list of training jobs")


@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job."""

    job_uuid: str = Field(..., description="The UUID of the job")
    checkpoints: list[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals


@json_schema_type
class SupervisedFineTuneRequest(BaseModel):
    """Request to run supervised fine-tuning of a model."""

    job_uuid: str = Field(..., description="The UUID of the job to create")
    training_config: TrainingConfig = Field(..., description="The training configuration")
    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
    logger_config: dict[str, Any] = Field(..., description="The logger configuration")
    model: str | None = Field(
        default=None,
        description="Model descriptor for training if not in provider config",
    )
    checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to")
    algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration")


@json_schema_type
class PreferenceOptimizeRequest(BaseModel):
    """Request to run preference optimization of a model."""

    job_uuid: str = Field(..., description="The UUID of the job to create")
    finetuned_model: str = Field(..., description="The model to fine-tune")
    algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration")
    training_config: TrainingConfig = Field(..., description="The training configuration")
    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
    logger_config: dict[str, Any] = Field(..., description="The logger configuration")
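Aside (editor's sketch, not part of this commit): AlgorithmConfig above is a discriminated union, so pydantic dispatches on the literal type field when validating a payload. A trimmed, self-contained reproduction of the pattern:

# Requires pydantic v2; the config classes are cut down for brevity.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int
    alpha: int


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[
    LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")
]

# The "type" key selects which model validates the payload.
cfg = TypeAdapter(AlgorithmConfig).validate_python({"type": "LoRA", "rank": 8, "alpha": 16})
print(type(cfg).__name__)  # LoraFinetuningConfig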
src/llama_stack/apis/post_training/post_training.py (deleted)
@@ -1,368 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class OptimizerType(Enum):
    """Available optimizer algorithms for training.
    :cvar adam: Adaptive Moment Estimation optimizer
    :cvar adamw: AdamW optimizer with weight decay
    :cvar sgd: Stochastic Gradient Descent optimizer
    """

    adam = "adam"
    adamw = "adamw"
    sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
    """Format of the training dataset.
    :cvar instruct: Instruction-following format with prompt and completion
    :cvar dialog: Multi-turn conversation format with messages
    """

    instruct = "instruct"
    dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
    """Configuration for training data and data loading.

    :param dataset_id: Unique identifier for the training dataset
    :param batch_size: Number of samples per training batch
    :param shuffle: Whether to shuffle the dataset during training
    :param data_format: Format of the dataset (instruct or dialog)
    :param validation_dataset_id: (Optional) Unique identifier for the validation dataset
    :param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
    :param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
    """

    dataset_id: str
    batch_size: int
    shuffle: bool
    data_format: DatasetFormat
    validation_dataset_id: str | None = None
    packed: bool | None = False
    train_on_input: bool | None = False


@json_schema_type
class OptimizerConfig(BaseModel):
    """Configuration parameters for the optimization algorithm.

    :param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
    :param lr: Learning rate for the optimizer
    :param weight_decay: Weight decay coefficient for regularization
    :param num_warmup_steps: Number of steps for learning rate warmup
    """

    optimizer_type: OptimizerType
    lr: float
    weight_decay: float
    num_warmup_steps: int


@json_schema_type
class EfficiencyConfig(BaseModel):
    """Configuration for memory and compute efficiency optimizations.

    :param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
    :param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
    :param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
    :param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
    """

    enable_activation_checkpointing: bool | None = False
    enable_activation_offloading: bool | None = False
    memory_efficient_fsdp_wrap: bool | None = False
    fsdp_cpu_offload: bool | None = False


@json_schema_type
class TrainingConfig(BaseModel):
    """Comprehensive configuration for the training process.

    :param n_epochs: Number of training epochs to run
    :param max_steps_per_epoch: Maximum number of steps to run per epoch
    :param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
    :param max_validation_steps: (Optional) Maximum number of validation steps per epoch
    :param data_config: (Optional) Configuration for data loading and formatting
    :param optimizer_config: (Optional) Configuration for the optimization algorithm
    :param efficiency_config: (Optional) Configuration for memory and compute optimizations
    :param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
    """

    n_epochs: int
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: int | None = 1
    data_config: DataConfig | None = None
    optimizer_config: OptimizerConfig | None = None
    efficiency_config: EfficiencyConfig | None = None
    dtype: str | None = "bf16"


@json_schema_type
class LoraFinetuningConfig(BaseModel):
    """Configuration for Low-Rank Adaptation (LoRA) fine-tuning.

    :param type: Algorithm type identifier, always "LoRA"
    :param lora_attn_modules: List of attention module names to apply LoRA to
    :param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
    :param apply_lora_to_output: Whether to apply LoRA to output projection layers
    :param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
    :param alpha: LoRA scaling parameter that controls adaptation strength
    :param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
    :param quantize_base: (Optional) Whether to quantize the base model weights
    """

    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: list[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
    rank: int
    alpha: int
    use_dora: bool | None = False
    quantize_base: bool | None = False


@json_schema_type
class QATFinetuningConfig(BaseModel):
    """Configuration for Quantization-Aware Training (QAT) fine-tuning.

    :param type: Algorithm type identifier, always "QAT"
    :param quantizer_name: Name of the quantization algorithm to use
    :param group_size: Size of groups for grouped quantization
    """

    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")


@json_schema_type
class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param log_lines: List of log message strings from the training process
    """

    job_uuid: str
    log_lines: list[str]


@json_schema_type
class RLHFAlgorithm(Enum):
    """Available reinforcement learning from human feedback algorithms.
    :cvar dpo: Direct Preference Optimization algorithm
    """

    dpo = "dpo"


@json_schema_type
class DPOLossType(Enum):
    sigmoid = "sigmoid"
    hinge = "hinge"
    ipo = "ipo"
    kto_pair = "kto_pair"


@json_schema_type
class DPOAlignmentConfig(BaseModel):
    """Configuration for Direct Preference Optimization (DPO) alignment.

    :param beta: Temperature parameter for the DPO loss
    :param loss_type: The type of loss function to use for DPO
    """

    beta: float
    loss_type: DPOLossType = DPOLossType.sigmoid


@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
    """Request to finetune a model using reinforcement learning from human feedback.

    :param job_uuid: Unique identifier for the training job
    :param finetuned_model: URL or path to the base model to fine-tune
    :param dataset_id: Unique identifier for the training dataset
    :param validation_dataset_id: Unique identifier for the validation dataset
    :param algorithm: RLHF algorithm to use for training
    :param algorithm_config: Configuration parameters for the RLHF algorithm
    :param optimizer_config: Configuration parameters for the optimization algorithm
    :param training_config: Configuration parameters for the training process
    :param hyperparam_search_config: Configuration for hyperparameter search
    :param logger_config: Configuration for training logging
    """

    job_uuid: str

    finetuned_model: URL

    dataset_id: str
    validation_dataset_id: str

    algorithm: RLHFAlgorithm
    algorithm_config: DPOAlignmentConfig

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: dict[str, Any]
    logger_config: dict[str, Any]


class PostTrainingJob(BaseModel):
    job_uuid: str


@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param status: Current status of the training job
    :param scheduled_at: (Optional) Timestamp when the job was scheduled
    :param started_at: (Optional) Timestamp when the job execution began
    :param completed_at: (Optional) Timestamp when the job finished, if completed
    :param resources_allocated: (Optional) Information about computational resources allocated to the job
    :param checkpoints: List of model checkpoints created during training
    """

    job_uuid: str
    status: JobStatus

    scheduled_at: datetime | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    resources_allocated: dict[str, Any] | None = None

    checkpoints: list[Checkpoint] = Field(default_factory=list)


class ListPostTrainingJobsResponse(BaseModel):
    data: list[PostTrainingJob]


@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param checkpoints: List of model checkpoints created during training
    """

    job_uuid: str
    checkpoints: list[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals


class PostTraining(Protocol):
    @webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
        model: str | None = Field(
            default=None,
            description="Model descriptor for training if not in provider config`",
        ),
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        """Run supervised fine-tuning of a model.

        :param job_uuid: The UUID of the job to create.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :param model: The model to fine-tune.
        :param checkpoint_dir: The directory to save checkpoint(s) to.
        :param algorithm_config: The algorithm configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
    ) -> PostTrainingJob:
        """Run preference optimization of a model.

        :param job_uuid: The UUID of the job to create.
        :param finetuned_model: The model to fine-tune.
        :param algorithm_config: The algorithm configuration.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        """Get all training jobs.

        :returns: A ListPostTrainingJobsResponse.
        """
        ...

    @webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
        """Get the status of a training job.

        :param job_uuid: The UUID of the job to get the status of.
        :returns: A PostTrainingJobStatusResponse.
        """
        ...

    @webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def cancel_training_job(self, job_uuid: str) -> None:
        """Cancel a training job.

        :param job_uuid: The UUID of the job to cancel.
        """
        ...

    @webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        """Get the artifacts of a training job.

        :param job_uuid: The UUID of the job to get the artifacts of.
        :returns: A PostTrainingJobArtifactsResponse.
        """
        ...
src/llama_stack/apis/post_training/post_training_service.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
AlgorithmConfig,
|
||||||
|
DPOAlignmentConfig,
|
||||||
|
ListPostTrainingJobsResponse,
|
||||||
|
PostTrainingJob,
|
||||||
|
PostTrainingJobArtifactsResponse,
|
||||||
|
PostTrainingJobStatusResponse,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
@trace_protocol
|
||||||
|
class PostTrainingService(Protocol):
|
||||||
|
async def supervised_fine_tune(
|
||||||
|
self,
|
||||||
|
job_uuid: str,
|
||||||
|
training_config: TrainingConfig,
|
||||||
|
hyperparam_search_config: dict[str, Any],
|
||||||
|
logger_config: dict[str, Any],
|
||||||
|
model: str | None = None,
|
||||||
|
checkpoint_dir: str | None = None,
|
||||||
|
algorithm_config: AlgorithmConfig | None = None,
|
||||||
|
) -> PostTrainingJob:
|
||||||
|
"""Run supervised fine-tuning of a model."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def preference_optimize(
|
||||||
|
self,
|
||||||
|
job_uuid: str,
|
||||||
|
finetuned_model: str,
|
||||||
|
algorithm_config: DPOAlignmentConfig,
|
||||||
|
training_config: TrainingConfig,
|
||||||
|
hyperparam_search_config: dict[str, Any],
|
||||||
|
logger_config: dict[str, Any],
|
||||||
|
) -> PostTrainingJob:
|
||||||
|
"""Run preference optimization of a model."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||||
|
"""Get all training jobs."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||||
|
"""Get the status of a training job."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||||
|
"""Cancel a training job."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||||
|
"""Get the artifacts of a training job."""
|
||||||
|
...
|
||||||
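Because PostTrainingService is a runtime_checkable Protocol, any object exposing the same method names satisfies it structurally. A hedged stub, not from this diff; the PostTrainingJob and ListPostTrainingJobsResponse constructor field names are assumptions.

from llama_stack.apis.post_training.models import ListPostTrainingJobsResponse, PostTrainingJob
from llama_stack.apis.post_training.post_training_service import PostTrainingService


class NoOpPostTraining:
    async def supervised_fine_tune(self, job_uuid, training_config, hyperparam_search_config,
                                   logger_config, model=None, checkpoint_dir=None, algorithm_config=None):
        return PostTrainingJob(job_uuid=job_uuid)  # field name assumed

    async def preference_optimize(self, job_uuid, finetuned_model, algorithm_config,
                                  training_config, hyperparam_search_config, logger_config):
        return PostTrainingJob(job_uuid=job_uuid)

    async def get_training_jobs(self):
        return ListPostTrainingJobsResponse(data=[])  # field name assumed

    async def get_training_job_status(self, job_uuid): ...

    async def cancel_training_job(self, job_uuid): ...

    async def get_training_job_artifacts(self, job_uuid): ...


# runtime_checkable isinstance() only verifies the method names exist:
print(isinstance(NoOpPostTraining(), PostTrainingService))  # True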
199
src/llama_stack/apis/post_training/routes.py
Normal file
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from fastapi import Body, Depends, Query, Request

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import (
    ListPostTrainingJobsResponse,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    PreferenceOptimizeRequest,
    SupervisedFineTuneRequest,
)
from .post_training_service import PostTrainingService


def get_post_training_service(request: Request) -> PostTrainingService:
    """Dependency to get the post training service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.post_training not in impls:
        raise ValueError("Post Training API implementation not found")
    return impls[Api.post_training]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Post Training"],
    responses=standard_responses,
)

router_v1alpha = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
    tags=["Post Training"],
    responses=standard_responses,
)


@router.post(
    "/post-training/supervised-fine-tune",
    response_model=PostTrainingJob,
    summary="Run supervised fine-tuning of a model",
    description="Run supervised fine-tuning of a model",
    deprecated=True,
)
@router_v1alpha.post(
    "/post-training/supervised-fine-tune",
    response_model=PostTrainingJob,
    summary="Run supervised fine-tuning of a model",
    description="Run supervised fine-tuning of a model",
)
async def supervised_fine_tune(
    body: SupervisedFineTuneRequest = Body(...),
    svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJob:
    """Run supervised fine-tuning of a model."""
    return await svc.supervised_fine_tune(
        job_uuid=body.job_uuid,
        training_config=body.training_config,
        hyperparam_search_config=body.hyperparam_search_config,
        logger_config=body.logger_config,
        model=body.model,
        checkpoint_dir=body.checkpoint_dir,
        algorithm_config=body.algorithm_config,
    )


@router.post(
    "/post-training/preference-optimize",
    response_model=PostTrainingJob,
    summary="Run preference optimization of a model",
    description="Run preference optimization of a model",
    deprecated=True,
)
@router_v1alpha.post(
    "/post-training/preference-optimize",
    response_model=PostTrainingJob,
    summary="Run preference optimization of a model",
    description="Run preference optimization of a model",
)
async def preference_optimize(
    body: PreferenceOptimizeRequest = Body(...),
    svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJob:
    """Run preference optimization of a model."""
    return await svc.preference_optimize(
        job_uuid=body.job_uuid,
        finetuned_model=body.finetuned_model,
        algorithm_config=body.algorithm_config,
        training_config=body.training_config,
        hyperparam_search_config=body.hyperparam_search_config,
        logger_config=body.logger_config,
    )


@router.get(
    "/post-training/jobs",
    response_model=ListPostTrainingJobsResponse,
    summary="Get all training jobs",
    description="Get all training jobs",
    deprecated=True,
)
@router_v1alpha.get(
    "/post-training/jobs",
    response_model=ListPostTrainingJobsResponse,
    summary="Get all training jobs",
    description="Get all training jobs",
)
async def get_training_jobs(
    svc: PostTrainingService = Depends(get_post_training_service),
) -> ListPostTrainingJobsResponse:
    """Get all training jobs."""
    return await svc.get_training_jobs()


@router.get(
    "/post-training/job/status",
    response_model=PostTrainingJobStatusResponse,
    summary="Get the status of a training job",
    description="Get the status of a training job",
    deprecated=True,
)
@router_v1alpha.get(
    "/post-training/job/status",
    response_model=PostTrainingJobStatusResponse,
    summary="Get the status of a training job",
    description="Get the status of a training job",
)
async def get_training_job_status(
    job_uuid: str = Query(..., description="The UUID of the job to get the status of"),
    svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJobStatusResponse:
    """Get the status of a training job."""
    return await svc.get_training_job_status(job_uuid=job_uuid)


@router.post(
    "/post-training/job/cancel",
    response_model=None,
    status_code=204,
    summary="Cancel a training job",
    description="Cancel a training job",
    deprecated=True,
)
@router_v1alpha.post(
    "/post-training/job/cancel",
    response_model=None,
    status_code=204,
    summary="Cancel a training job",
    description="Cancel a training job",
)
async def cancel_training_job(
    job_uuid: str = Query(..., description="The UUID of the job to cancel"),
    svc: PostTrainingService = Depends(get_post_training_service),
) -> None:
    """Cancel a training job."""
    await svc.cancel_training_job(job_uuid=job_uuid)


@router.get(
    "/post-training/job/artifacts",
    response_model=PostTrainingJobArtifactsResponse,
    summary="Get the artifacts of a training job",
    description="Get the artifacts of a training job",
    deprecated=True,
)
@router_v1alpha.get(
    "/post-training/job/artifacts",
    response_model=PostTrainingJobArtifactsResponse,
    summary="Get the artifacts of a training job",
    description="Get the artifacts of a training job",
)
async def get_training_job_artifacts(
    job_uuid: str = Query(..., description="The UUID of the job to get the artifacts of"),
    svc: PostTrainingService = Depends(get_post_training_service),
) -> PostTrainingJobArtifactsResponse:
    """Get the artifacts of a training job."""
    return await svc.get_training_job_artifacts(job_uuid=job_uuid)


# For backward compatibility with the router registry system
def create_post_training_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Post Training API (legacy compatibility)."""
    main_router = APIRouter()
    main_router.include_router(router)
    main_router.include_router(router_v1alpha)
    return main_router


# Register the router factory
register_router(Api.post_training, create_post_training_router)
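The get_post_training_service dependency resolves the implementation from request.app.state.impls, keyed by the Api enum. A sketch of how a host application could satisfy that lookup, assuming the routers are mounted directly; the real server's wiring may differ.

from fastapi import FastAPI

from llama_stack.apis.datatypes import Api
from llama_stack.apis.post_training.routes import router, router_v1alpha

app = FastAPI()
app.include_router(router)          # /v1/post-training/... (deprecated routes)
app.include_router(router_v1alpha)  # /v1alpha/post-training/...

# `my_impl` stands in for any object satisfying PostTrainingService,
# e.g. the NoOpPostTraining stub sketched earlier.
my_impl = ...
app.state.impls = {Api.post_training: my_impl}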
src/llama_stack/apis/prompts/__init__.py
@@ -4,6 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .prompts import ListPromptsResponse, Prompt, Prompts
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .models import (
+    CreatePromptRequest,
+    ListPromptsResponse,
+    Prompt,
+    SetDefaultVersionRequest,
+    UpdatePromptRequest,
+)
+from .prompts_service import PromptService
+
+# Backward compatibility - export Prompts as alias for PromptService
+Prompts = PromptService
 
-__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
+__all__ = [
+    "Prompts",
+    "PromptService",
+    "Prompt",
+    "ListPromptsResponse",
+    "CreatePromptRequest",
+    "UpdatePromptRequest",
+    "SetDefaultVersionRequest",
+]
113
src/llama_stack/apis/prompts/models.py
Normal file
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
import secrets

from pydantic import BaseModel, Field, field_validator, model_validator

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class Prompt(BaseModel):
    """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."""

    prompt: str | None = Field(
        default=None,
        description="The system prompt text with variable placeholders. Variables are only supported when using the Responses API.",
    )
    version: int = Field(description="Version (integer starting at 1, incremented on save).", ge=1)
    prompt_id: str = Field(description="Unique identifier formatted as 'pmpt_<48-digit-hash>'.")
    variables: list[str] = Field(
        default_factory=list, description="List of prompt variable names that can be used in the prompt template."
    )
    is_default: bool = Field(
        default=False, description="Boolean indicating whether this version is the default version for this prompt."
    )

    @field_validator("prompt_id")
    @classmethod
    def validate_prompt_id(cls, prompt_id: str) -> str:
        if not isinstance(prompt_id, str):
            raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")

        if not prompt_id.startswith("pmpt_"):
            raise ValueError("prompt_id must start with 'pmpt_' prefix")

        hex_part = prompt_id[5:]
        if len(hex_part) != 48:
            raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")

        for char in hex_part:
            if char not in "0123456789abcdef":
                raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")

        return prompt_id

    @field_validator("version")
    @classmethod
    def validate_version(cls, prompt_version: int) -> int:
        if prompt_version < 1:
            raise ValueError("version must be >= 1")
        return prompt_version

    @model_validator(mode="after")
    def validate_prompt_variables(self):
        """Validate that all variables used in the prompt are declared in the variables list."""
        if not self.prompt:
            return self

        prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
        declared_variables = set(self.variables)

        undeclared = prompt_variables - declared_variables
        if undeclared:
            raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")

        return self

    @classmethod
    def generate_prompt_id(cls) -> str:
        # Generate 48 hex characters (24 bytes)
        random_bytes = secrets.token_bytes(24)
        hex_string = random_bytes.hex()
        return f"pmpt_{hex_string}"


class ListPromptsResponse(BaseModel):
    """Response model to list prompts."""

    data: list[Prompt] = Field(description="List of prompt resources.")


@json_schema_type
class CreatePromptRequest(BaseModel):
    """Request model for creating a new prompt."""

    prompt: str = Field(..., description="The prompt text content with variable placeholders.")
    variables: list[str] | None = Field(
        default=None, description="List of variable names that can be used in the prompt template."
    )


@json_schema_type
class UpdatePromptRequest(BaseModel):
    """Request model for updating an existing prompt."""

    prompt: str = Field(..., description="The updated prompt text content.")
    version: int = Field(..., description="The current version of the prompt being updated.")
    variables: list[str] | None = Field(
        default=None, description="Updated list of variable names that can be used in the prompt template."
    )
    set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")


@json_schema_type
class SetDefaultVersionRequest(BaseModel):
    """Request model for setting a prompt version as default."""

    version: int = Field(..., description="The version to set as default.")
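A short sketch exercising the validators defined above; the import path follows this diff's layout. Note that pydantic wraps the ValueError raised by the model validator in a ValidationError, which is itself a ValueError subclass.

from llama_stack.apis.prompts.models import Prompt

pid = Prompt.generate_prompt_id()  # "pmpt_" + 48 lowercase hex chars
p = Prompt(prompt="Hello {{ name }}", version=1, prompt_id=pid, variables=["name"])

# The model validator rejects placeholders missing from `variables`:
try:
    Prompt(prompt="Hello {{ name }}", version=1, prompt_id=pid, variables=[])
except ValueError as e:
    print(e)  # mentions: Prompt contains undeclared variables: ['name']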
src/llama_stack/apis/prompts/prompts.py
@@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
import secrets
from typing import Protocol, runtime_checkable

from pydantic import BaseModel, Field, field_validator, model_validator

from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class Prompt(BaseModel):
    """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.

    :param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
    :param version: Version (integer starting at 1, incremented on save)
    :param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
    :param variables: List of prompt variable names that can be used in the prompt template
    :param is_default: Boolean indicating whether this version is the default version for this prompt
    """

    prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
    version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
    prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
    variables: list[str] = Field(
        default_factory=list, description="List of variable names that can be used in the prompt template"
    )
    is_default: bool = Field(
        default=False, description="Boolean indicating whether this version is the default version"
    )

    @field_validator("prompt_id")
    @classmethod
    def validate_prompt_id(cls, prompt_id: str) -> str:
        if not isinstance(prompt_id, str):
            raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")

        if not prompt_id.startswith("pmpt_"):
            raise ValueError("prompt_id must start with 'pmpt_' prefix")

        hex_part = prompt_id[5:]
        if len(hex_part) != 48:
            raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")

        for char in hex_part:
            if char not in "0123456789abcdef":
                raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")

        return prompt_id

    @field_validator("version")
    @classmethod
    def validate_version(cls, prompt_version: int) -> int:
        if prompt_version < 1:
            raise ValueError("version must be >= 1")
        return prompt_version

    @model_validator(mode="after")
    def validate_prompt_variables(self):
        """Validate that all variables used in the prompt are declared in the variables list."""
        if not self.prompt:
            return self

        prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
        declared_variables = set(self.variables)

        undeclared = prompt_variables - declared_variables
        if undeclared:
            raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")

        return self

    @classmethod
    def generate_prompt_id(cls) -> str:
        # Generate 48 hex characters (24 bytes)
        random_bytes = secrets.token_bytes(24)
        hex_string = random_bytes.hex()
        return f"pmpt_{hex_string}"


class ListPromptsResponse(BaseModel):
    """Response model to list prompts."""

    data: list[Prompt]


@runtime_checkable
@trace_protocol
class Prompts(Protocol):
    """Prompts

    Protocol for prompt management operations."""

    @webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
    async def list_prompts(self) -> ListPromptsResponse:
        """List all prompts.

        :returns: A ListPromptsResponse containing all prompts.
        """
        ...

    @webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
    async def list_prompt_versions(
        self,
        prompt_id: str,
    ) -> ListPromptsResponse:
        """List prompt versions.

        List all versions of a specific prompt.

        :param prompt_id: The identifier of the prompt to list versions for.
        :returns: A ListPromptsResponse containing all versions of the prompt.
        """
        ...

    @webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_prompt(
        self,
        prompt_id: str,
        version: int | None = None,
    ) -> Prompt:
        """Get prompt.

        Get a prompt by its identifier and optional version.

        :param prompt_id: The identifier of the prompt to get.
        :param version: The version of the prompt to get (defaults to latest).
        :returns: A Prompt resource.
        """
        ...

    @webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
    async def create_prompt(
        self,
        prompt: str,
        variables: list[str] | None = None,
    ) -> Prompt:
        """Create prompt.

        Create a new prompt.

        :param prompt: The prompt text content with variable placeholders.
        :param variables: List of variable names that can be used in the prompt template.
        :returns: The created Prompt resource.
        """
        ...

    @webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
    async def update_prompt(
        self,
        prompt_id: str,
        prompt: str,
        version: int,
        variables: list[str] | None = None,
        set_as_default: bool = True,
    ) -> Prompt:
        """Update prompt.

        Update an existing prompt (increments version).

        :param prompt_id: The identifier of the prompt to update.
        :param prompt: The updated prompt text content.
        :param version: The current version of the prompt being updated.
        :param variables: Updated list of variable names that can be used in the prompt template.
        :param set_as_default: Set the new version as the default (default=True).
        :returns: The updated Prompt resource with incremented version.
        """
        ...

    @webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def delete_prompt(
        self,
        prompt_id: str,
    ) -> None:
        """Delete prompt.

        Delete a prompt.

        :param prompt_id: The identifier of the prompt to delete.
        """
        ...

    @webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
    async def set_default_version(
        self,
        prompt_id: str,
        version: int,
    ) -> Prompt:
        """Set prompt version.

        Set which version of a prompt should be the default in get_prompt (latest).

        :param prompt_id: The identifier of the prompt.
        :param version: The version to set as default.
        :returns: The prompt with the specified version now set as default.
        """
        ...
72
src/llama_stack/apis/prompts/prompts_service.py
Normal file
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import ListPromptsResponse, Prompt


@runtime_checkable
@trace_protocol
class PromptService(Protocol):
    """Prompts

    Protocol for prompt management operations."""

    async def list_prompts(self) -> ListPromptsResponse:
        """List all prompts."""
        ...

    async def list_prompt_versions(
        self,
        prompt_id: str,
    ) -> ListPromptsResponse:
        """List prompt versions."""
        ...

    async def get_prompt(
        self,
        prompt_id: str,
        version: int | None = None,
    ) -> Prompt:
        """Get prompt."""
        ...

    async def create_prompt(
        self,
        prompt: str,
        variables: list[str] | None = None,
    ) -> Prompt:
        """Create prompt."""
        ...

    async def update_prompt(
        self,
        prompt_id: str,
        prompt: str,
        version: int,
        variables: list[str] | None = None,
        set_as_default: bool = True,
    ) -> Prompt:
        """Update prompt."""
        ...

    async def delete_prompt(
        self,
        prompt_id: str,
    ) -> None:
        """Delete prompt."""
        ...

    async def set_default_version(
        self,
        prompt_id: str,
        version: int,
    ) -> Prompt:
        """Set prompt version."""
        ...
154
src/llama_stack/apis/prompts/routes.py
Normal file
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Query, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import (
    CreatePromptRequest,
    ListPromptsResponse,
    Prompt,
    SetDefaultVersionRequest,
    UpdatePromptRequest,
)
from .prompts_service import PromptService


def get_prompt_service(request: Request) -> PromptService:
    """Dependency to get the prompt service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.prompts not in impls:
        raise ValueError("Prompts API implementation not found")
    return impls[Api.prompts]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Prompts"],
    responses=standard_responses,
)


@router.get(
    "/prompts",
    response_model=ListPromptsResponse,
    summary="List all prompts",
    description="List all prompts registered in Llama Stack",
)
async def list_prompts(svc: PromptService = Depends(get_prompt_service)) -> ListPromptsResponse:
    """List all prompts."""
    return await svc.list_prompts()


@router.get(
    "/prompts/{prompt_id}/versions",
    response_model=ListPromptsResponse,
    summary="List prompt versions",
    description="List all versions of a specific prompt",
)
async def list_prompt_versions(
    prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to list versions for")],
    svc: PromptService = Depends(get_prompt_service),
) -> ListPromptsResponse:
    """List prompt versions."""
    return await svc.list_prompt_versions(prompt_id=prompt_id)


@router.get(
    "/prompts/{prompt_id}",
    response_model=Prompt,
    summary="Get prompt",
    description="Get a prompt by its identifier and optional version",
)
async def get_prompt(
    prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to get")],
    version: int | None = Query(None, description="The version of the prompt to get (defaults to latest)"),
    svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
    """Get prompt by its identifier and optional version."""
    return await svc.get_prompt(prompt_id=prompt_id, version=version)


@router.post(
    "/prompts",
    response_model=Prompt,
    summary="Create prompt",
    description="Create a new prompt",
)
async def create_prompt(
    body: CreatePromptRequest = Body(...),
    svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
    """Create a new prompt."""
    return await svc.create_prompt(prompt=body.prompt, variables=body.variables)


@router.post(
    "/prompts/{prompt_id}",
    response_model=Prompt,
    summary="Update prompt",
    description="Update an existing prompt (increments version)",
)
async def update_prompt(
    prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to update")],
    body: UpdatePromptRequest = Body(...),
    svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
    """Update an existing prompt."""
    return await svc.update_prompt(
        prompt_id=prompt_id,
        prompt=body.prompt,
        version=body.version,
        variables=body.variables,
        set_as_default=body.set_as_default,
    )


@router.delete(
    "/prompts/{prompt_id}",
    response_model=None,
    status_code=204,
    summary="Delete prompt",
    description="Delete a prompt",
)
async def delete_prompt(
    prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt to delete")],
    svc: PromptService = Depends(get_prompt_service),
) -> None:
    """Delete a prompt."""
    await svc.delete_prompt(prompt_id=prompt_id)


@router.post(
    "/prompts/{prompt_id}/set-default-version",
    response_model=Prompt,
    summary="Set prompt version",
    description="Set which version of a prompt should be the default in get_prompt (latest)",
)
async def set_default_version(
    prompt_id: Annotated[str, FastAPIPath(..., description="The identifier of the prompt")],
    body: SetDefaultVersionRequest = Body(...),
    svc: PromptService = Depends(get_prompt_service),
) -> Prompt:
    """Set which version of a prompt should be the default."""
    return await svc.set_default_version(prompt_id=prompt_id, version=body.version)


# For backward compatibility with the router registry system
def create_prompts_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Prompts API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.prompts, create_prompts_router)
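A sketch of the prompt lifecycle over HTTP against the routes above; the base URL is an assumption, while the paths and payload fields come from this diff. Note the new routes use POST for updates, unlike the PUT of the removed webmethod protocol.

import requests

BASE = "http://localhost:8321/v1"  # assumed server address

created = requests.post(f"{BASE}/prompts", json={
    "prompt": "Summarize {{ text }}",
    "variables": ["text"],
}).json()
pid = created["prompt_id"]

# Update creates a new version and (by default) makes it the default version.
requests.post(f"{BASE}/prompts/{pid}", json={
    "prompt": "Summarize {{ text }} briefly",
    "version": created["version"],
    "variables": ["text"],
})

# Pin version 1 back as the default.
requests.post(f"{BASE}/prompts/{pid}/set-default-version", json={"version": 1})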
src/llama_stack/apis/providers/__init__.py
@@ -4,4 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .providers import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .models import ListProvidersResponse, ProviderInfo
+from .providers_service import ProviderService
+
+# Backward compatibility - export Providers as alias for ProviderService
+Providers = ProviderService
+
+__all__ = ["Providers", "ProviderService", "ListProvidersResponse", "ProviderInfo"]
29
src/llama_stack/apis/providers/models.py
Normal file
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ProviderInfo(BaseModel):
    """Information about a registered provider including its configuration and health status."""

    api: str = Field(..., description="The API name this provider implements")
    provider_id: str = Field(..., description="Unique identifier for the provider")
    provider_type: str = Field(..., description="The type of provider implementation")
    config: dict[str, Any] = Field(..., description="Configuration parameters for the provider")
    health: HealthResponse = Field(..., description="Current health status of the provider")


class ListProvidersResponse(BaseModel):
    """Response containing a list of all available providers."""

    data: list[ProviderInfo] = Field(..., description="List of provider information objects")
src/llama_stack/apis/providers/providers.py
@@ -1,69 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
    """Information about a registered provider including its configuration and health status.

    :param api: The API name this provider implements
    :param provider_id: Unique identifier for the provider
    :param provider_type: The type of provider implementation
    :param config: Configuration parameters for the provider
    :param health: Current health status of the provider
    """

    api: str
    provider_id: str
    provider_type: str
    config: dict[str, Any]
    health: HealthResponse


class ListProvidersResponse(BaseModel):
    """Response containing a list of all available providers.

    :param data: List of provider information objects
    """

    data: list[ProviderInfo]


@runtime_checkable
class Providers(Protocol):
    """Providers

    Providers API for inspecting, listing, and modifying providers and their configurations.
    """

    @webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
    async def list_providers(self) -> ListProvidersResponse:
        """List providers.

        List all available providers.

        :returns: A ListProvidersResponse containing information about all providers.
        """
        ...

    @webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        """Get provider.

        Get detailed information about a specific provider.

        :param provider_id: The ID of the provider to inspect.
        :returns: A ProviderInfo object containing the provider's details.
        """
        ...
25
src/llama_stack/apis/providers/providers_service.py
Normal file
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from .models import ListProvidersResponse, ProviderInfo


@runtime_checkable
class ProviderService(Protocol):
    """Providers

    Providers API for inspecting, listing, and modifying providers and their configurations.
    """

    async def list_providers(self) -> ListProvidersResponse:
        """List providers."""
        ...

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        """Get provider."""
        ...
68
src/llama_stack/apis/providers/routes.py
Normal file
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import ListProvidersResponse, ProviderInfo
from .providers_service import ProviderService


def get_provider_service(request: Request) -> ProviderService:
    """Dependency to get the provider service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.providers not in impls:
        raise ValueError("Providers API implementation not found")
    return impls[Api.providers]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Providers"],
    responses=standard_responses,
)


@router.get(
    "/providers",
    response_model=ListProvidersResponse,
    summary="List providers",
    description="List all available providers",
)
async def list_providers(svc: ProviderService = Depends(get_provider_service)) -> ListProvidersResponse:
    """List all available providers."""
    return await svc.list_providers()


@router.get(
    "/providers/{provider_id}",
    response_model=ProviderInfo,
    summary="Get provider",
    description="Get detailed information about a specific provider",
)
async def inspect_provider(
    provider_id: Annotated[str, FastAPIPath(..., description="The ID of the provider to inspect")],
    svc: ProviderService = Depends(get_provider_service),
) -> ProviderInfo:
    """Get detailed information about a specific provider."""
    return await svc.inspect_provider(provider_id=provider_id)


# For backward compatibility with the router registry system
def create_providers_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Providers API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.providers, create_providers_router)
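A sketch of querying the provider routes above over HTTP; the base URL and the "ollama" provider id are assumptions, while the paths and the response fields follow ProviderInfo and ListProvidersResponse from this diff.

import requests

BASE = "http://localhost:8321/v1"  # assumed server address

for p in requests.get(f"{BASE}/providers").json()["data"]:
    print(p["api"], p["provider_id"], p["health"])

detail = requests.get(f"{BASE}/providers/ollama").json()  # provider id assumed
print(detail["config"])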
src/llama_stack/apis/safety/__init__.py
@@ -4,4 +4,31 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .safety import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .models import (
+    ModerationObject,
+    ModerationObjectResults,
+    RunModerationRequest,
+    RunShieldRequest,
+    RunShieldResponse,
+    SafetyViolation,
+    ViolationLevel,
+)
+from .safety_service import SafetyService, ShieldStore
+
+# Backward compatibility - export Safety as alias for SafetyService
+Safety = SafetyService
+
+__all__ = [
+    "Safety",
+    "SafetyService",
+    "ShieldStore",
+    "ModerationObject",
+    "ModerationObjectResults",
+    "RunShieldRequest",
+    "RunShieldResponse",
+    "RunModerationRequest",
+    "SafetyViolation",
+    "ViolationLevel",
+]
96
src/llama_stack/apis/safety/models.py
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any

from pydantic import BaseModel, Field

from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ModerationObjectResults(BaseModel):
    """A moderation object."""

    flagged: bool = Field(..., description="Whether any of the below categories are flagged.")
    categories: dict[str, bool] | None = Field(
        default=None, description="A list of the categories, and whether they are flagged or not."
    )
    category_applied_input_types: dict[str, list[str]] | None = Field(
        default=None,
        description="A list of the categories along with the input type(s) that the score applies to.",
    )
    category_scores: dict[str, float] | None = Field(
        default=None, description="A list of the categories along with their scores as predicted by model."
    )
    user_message: str | None = Field(default=None, description="User message.")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata.")


@json_schema_type
class ModerationObject(BaseModel):
    """A moderation object."""

    id: str = Field(..., description="The unique identifier for the moderation request.")
    model: str = Field(..., description="The model used to generate the moderation results.")
    results: list[ModerationObjectResults] = Field(..., description="A list of moderation objects.")


@json_schema_type
class ViolationLevel(Enum):
    """Severity level of a safety violation.

    :cvar INFO: Informational level violation that does not require action
    :cvar WARN: Warning level violation that suggests caution but allows continuation
    :cvar ERROR: Error level violation that requires blocking or intervention
    """

    INFO = "info"
    WARN = "warn"
    ERROR = "error"


@json_schema_type
class SafetyViolation(BaseModel):
    """Details of a safety violation detected by content moderation."""

    violation_level: ViolationLevel = Field(..., description="Severity level of the violation.")
    user_message: str | None = Field(default=None, description="Message to convey to the user about the violation.")
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata including specific violation codes for debugging and telemetry.",
    )


@json_schema_type
class RunShieldResponse(BaseModel):
    """Response from running a safety shield."""

    violation: SafetyViolation | None = Field(
        default=None, description="Safety violation detected by the shield, if any."
    )


@json_schema_type
class RunShieldRequest(BaseModel):
    """Request model for running a shield."""

    shield_id: str = Field(..., description="The identifier of the shield to run.")
    messages: list[OpenAIMessageParam] = Field(..., description="The messages to run the shield on.")
    params: dict[str, Any] = Field(..., description="The parameters of the shield.")


@json_schema_type
class RunModerationRequest(BaseModel):
    """Request model for running moderation."""

    input: str | list[str] = Field(
        ...,
        description="Input (or inputs) to classify. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.",
    )
    model: str | None = Field(default=None, description="The content moderation model you would like to use.")
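A short sketch building the moderation payload shapes defined above. The category names are illustrative only; the schema does not fix a category taxonomy.

from llama_stack.apis.safety.models import ModerationObject, ModerationObjectResults

result = ModerationObjectResults(
    flagged=True,
    categories={"violence": True, "self-harm": False},
    category_scores={"violence": 0.91, "self-harm": 0.02},
)
obj = ModerationObject(id="modr-001", model="moderation-model", results=[result])
print(any(r.flagged for r in obj.results))  # True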
68
src/llama_stack/apis/safety/routes.py
Normal file
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from fastapi import Body, Depends, Request

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import ModerationObject, RunModerationRequest, RunShieldRequest, RunShieldResponse
from .safety_service import SafetyService


def get_safety_service(request: Request) -> SafetyService:
    """Dependency to get the safety service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.safety not in impls:
        raise ValueError("Safety API implementation not found")
    return impls[Api.safety]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Safety"],
    responses=standard_responses,
)


@router.post(
    "/safety/run-shield",
    response_model=RunShieldResponse,
    summary="Run shield.",
    description="Run a shield.",
)
async def run_shield(
    body: RunShieldRequest = Body(...),
    svc: SafetyService = Depends(get_safety_service),
) -> RunShieldResponse:
    """Run a shield."""
    return await svc.run_shield(shield_id=body.shield_id, messages=body.messages, params=body.params)


@router.post(
    "/moderations",
    response_model=ModerationObject,
    summary="Create moderation.",
    description="Classifies if text and/or image inputs are potentially harmful.",
)
async def run_moderation(
    body: RunModerationRequest = Body(...),
    svc: SafetyService = Depends(get_safety_service),
) -> ModerationObject:
    """Create moderation."""
    return await svc.run_moderation(input=body.input, model=body.model)


# For backward compatibility with the router registry system
def create_safety_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Safety API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.safety, create_safety_router)
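Since /v1/moderations mirrors the OpenAI moderations shape, a plain HTTP call works. A sketch; the base URL and the "llama-guard" model/shield ids are assumptions, while the paths and body fields come from the routes above.

import requests

BASE = "http://localhost:8321/v1"  # assumed server address

mod = requests.post(f"{BASE}/moderations", json={
    "input": "some text to classify",
    "model": "llama-guard",  # hypothetical moderation model id
}).json()
print(mod["results"][0]["flagged"])

shield = requests.post(f"{BASE}/safety/run-shield", json={
    "shield_id": "llama-guard",  # hypothetical shield id
    "messages": [{"role": "user", "content": "hello"}],
    "params": {},
}).json()
print(shield.get("violation"))  # None when the shield passes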
@ -1,134 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import OpenAIMessageParam
|
|
||||||
from llama_stack.apis.shields import Shield
|
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
|
||||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ModerationObjectResults(BaseModel):
|
|
||||||
"""A moderation object.
|
|
||||||
:param flagged: Whether any of the below categories are flagged.
|
|
||||||
:param categories: A list of the categories, and whether they are flagged or not.
|
|
||||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
|
||||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
flagged: bool
|
|
||||||
categories: dict[str, bool] | None = None
|
|
||||||
category_applied_input_types: dict[str, list[str]] | None = None
|
|
||||||
category_scores: dict[str, float] | None = None
|
|
||||||
user_message: str | None = None
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ModerationObject(BaseModel):
|
|
||||||
"""A moderation object.
|
|
||||||
:param id: The unique identifier for the moderation request.
|
|
    :param model: The model used to generate the moderation results.
    :param results: A list of moderation objects
    """

    id: str
    model: str
    results: list[ModerationObjectResults]


@json_schema_type
class ViolationLevel(Enum):
    """Severity level of a safety violation.

    :cvar INFO: Informational level violation that does not require action
    :cvar WARN: Warning level violation that suggests caution but allows continuation
    :cvar ERROR: Error level violation that requires blocking or intervention
    """

    INFO = "info"
    WARN = "warn"
    ERROR = "error"


@json_schema_type
class SafetyViolation(BaseModel):
    """Details of a safety violation detected by content moderation.

    :param violation_level: Severity level of the violation
    :param user_message: (Optional) Message to convey to the user about the violation
    :param metadata: Additional metadata including specific violation codes for debugging and telemetry
    """

    violation_level: ViolationLevel

    # what message should you convey to the user
    user_message: str | None = None

    # additional metadata (including specific violation codes) more for
    # debugging, telemetry
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RunShieldResponse(BaseModel):
    """Response from running a safety shield.

    :param violation: (Optional) Safety violation detected by the shield, if any
    """

    violation: SafetyViolation | None = None


class ShieldStore(Protocol):
    async def get_shield(self, identifier: str) -> Shield: ...


@runtime_checkable
@trace_protocol
class Safety(Protocol):
    """Safety

    OpenAI-compatible Moderations API.
    """

    shield_store: ShieldStore

    @webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
    async def run_shield(
        self,
        shield_id: str,
        messages: list[OpenAIMessageParam],
        params: dict[str, Any],
    ) -> RunShieldResponse:
        """Run shield.

        Run a shield.

        :param shield_id: The identifier of the shield to run.
        :param messages: The messages to run the shield on.
        :param params: The parameters of the shield.
        :returns: A RunShieldResponse.
        """
        ...

    @webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
        """Create moderation.

        Classifies if text and/or image inputs are potentially harmful.

        :param input: Input (or inputs) to classify.
            Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
        :param model: (Optional) The content moderation model you would like to use.
        :returns: A moderation object.
        """
        ...
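
An aside on how these safety models compose: the snippet below is a minimal sketch using local stand-ins that mirror ViolationLevel, SafetyViolation, and RunShieldResponse above, so it runs without importing llama_stack; the violation code "S1" and the message text are illustrative only.

import asyncio  # noqa: F401  (kept to show the models are used in async providers)
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class ViolationLevel(Enum):  # local mirror of the enum above
    INFO = "info"
    WARN = "warn"
    ERROR = "error"


class SafetyViolation(BaseModel):  # local mirror
    violation_level: ViolationLevel
    user_message: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


class RunShieldResponse(BaseModel):  # local mirror
    violation: SafetyViolation | None = None


# A shield that found nothing returns an empty response; a blocking shield
# attaches an ERROR-level violation with a user-facing message.
clean = RunShieldResponse()
blocked = RunShieldResponse(
    violation=SafetyViolation(
        violation_level=ViolationLevel.ERROR,
        user_message="I can't help with that request.",
        metadata={"violation_code": "S1"},
    )
)
print(clean.violation is None, blocked.violation.violation_level.value)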
41 src/llama_stack/apis/safety/safety_service.py Normal file
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.core.telemetry.trace_protocol import trace_protocol

from .models import ModerationObject, RunShieldResponse


class ShieldStore(Protocol):
    async def get_shield(self, identifier: str) -> Shield: ...


@runtime_checkable
@trace_protocol
class SafetyService(Protocol):
    """Safety

    OpenAI-compatible Moderations API.
    """

    shield_store: ShieldStore

    async def run_shield(
        self,
        shield_id: str,
        messages: list[OpenAIMessageParam],
        params: dict[str, Any],
    ) -> RunShieldResponse:
        """Run shield."""
        ...

    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
        """Create moderation."""
        ...
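
A hedged sketch of what @runtime_checkable buys here: any object that structurally exposes the protocol's members passes an isinstance() check without inheriting from SafetyService. The protocol below is a local mirror of the one in this file; InProcessSafety is invented for illustration.

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SafetyService(Protocol):  # local mirror of the protocol in this file
    shield_store: Any

    async def run_shield(self, shield_id: str, messages: list[Any], params: dict[str, Any]) -> Any: ...

    async def run_moderation(self, input: str | list[str], model: str | None = None) -> Any: ...


class InProcessSafety:
    """Invented stub: no inheritance, just matching structure."""

    shield_store = None

    async def run_shield(self, shield_id, messages, params):
        return {"violation": None}

    async def run_moderation(self, input, model=None):
        return {"id": "modr-1", "model": model or "stub", "results": []}


print(isinstance(InProcessSafety(), SafetyService))  # True: structural, not nominal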
@@ -4,4 +4,29 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .scoring import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .models import (
+    ScoreBatchRequest,
+    ScoreBatchResponse,
+    ScoreRequest,
+    ScoreResponse,
+    ScoringResult,
+    ScoringResultRow,
+)
+from .scoring_service import ScoringFunctionStore, ScoringService
+
+# Backward compatibility - export Scoring as alias for ScoringService
+Scoring = ScoringService
+
+__all__ = [
+    "Scoring",
+    "ScoringService",
+    "ScoringFunctionStore",
+    "ScoreBatchRequest",
+    "ScoreBatchResponse",
+    "ScoreRequest",
+    "ScoreResponse",
+    "ScoringResult",
+    "ScoringResultRow",
+]
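
A brief usage note (assuming this commit's package layout is importable): the alias keeps pre-split imports working unchanged.

from llama_stack.apis.scoring import Scoring, ScoringService

assert Scoring is ScoringService  # the old name resolves to the new service protocol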
61 src/llama_stack/apis/scoring/models.py Normal file
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type

# mapping of metric to value
ScoringResultRow = dict[str, Any]


@json_schema_type
class ScoringResult(BaseModel):
    """A scoring result for a single row."""

    score_rows: list[ScoringResultRow] = Field(
        ..., description="The scoring result for each row. Each row is a map of column name to value"
    )
    aggregated_results: dict[str, Any] = Field(..., description="Map of metric name to aggregated value")


@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Response from batch scoring operations on datasets."""

    dataset_id: str | None = Field(default=None, description="The identifier of the dataset that was scored")
    results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult")


@json_schema_type
class ScoreResponse(BaseModel):
    """The response from scoring."""

    results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult")


@json_schema_type
class ScoreBatchRequest(BaseModel):
    """Request for batch scoring operations."""

    dataset_id: str = Field(..., description="The ID of the dataset to score")
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        ..., description="The scoring functions to use for the scoring"
    )
    save_results_dataset: bool = Field(default=False, description="Whether to save the results to a dataset")


@json_schema_type
class ScoreRequest(BaseModel):
    """Request for scoring a list of rows."""

    input_rows: list[dict[str, Any]] = Field(..., description="The rows to score")
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        ..., description="The scoring functions to use for the scoring"
    )
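
A quick sketch of these response shapes in use. The classes are redefined locally as field-for-field mirrors of the models above so the snippet runs standalone, and the scoring-function id "basic::equality" is just an illustrative name.

from typing import Any

from pydantic import BaseModel

ScoringResultRow = dict[str, Any]  # local mirror of the alias above


class ScoringResult(BaseModel):  # local mirror
    score_rows: list[ScoringResultRow]
    aggregated_results: dict[str, Any]


class ScoreResponse(BaseModel):  # local mirror
    results: dict[str, ScoringResult]


# A response scoring two rows with one function: per-row scores plus an aggregate.
resp = ScoreResponse(
    results={
        "basic::equality": ScoringResult(
            score_rows=[{"score": 1.0}, {"score": 0.0}],
            aggregated_results={"accuracy": 0.5},
        )
    }
)
print(resp.model_dump()["results"]["basic::equality"]["aggregated_results"])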
75 src/llama_stack/apis/scoring/routes.py Normal file
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from fastapi import Body, Depends, Request

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
from .scoring_service import ScoringService


def get_scoring_service(request: Request) -> ScoringService:
    """Dependency to get the scoring service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.scoring not in impls:
        raise ValueError("Scoring API implementation not found")
    return impls[Api.scoring]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Scoring"],
    responses=standard_responses,
)


@router.post(
    "/scoring/score-batch",
    response_model=ScoreBatchResponse,
    summary="Score a batch of rows",
    description="Score a batch of rows from a dataset",
)
async def score_batch(
    body: ScoreBatchRequest = Body(...),
    svc: ScoringService = Depends(get_scoring_service),
) -> ScoreBatchResponse:
    """Score a batch of rows from a dataset."""
    return await svc.score_batch(
        dataset_id=body.dataset_id,
        scoring_functions=body.scoring_functions,
        save_results_dataset=body.save_results_dataset,
    )


@router.post(
    "/scoring/score",
    response_model=ScoreResponse,
    summary="Score a list of rows",
    description="Score a list of rows",
)
async def score(
    body: ScoreRequest = Body(...),
    svc: ScoringService = Depends(get_scoring_service),
) -> ScoreResponse:
    """Score a list of rows."""
    return await svc.score(
        input_rows=body.input_rows,
        scoring_functions=body.scoring_functions,
    )


# For backward compatibility with the router registry system
def create_scoring_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Scoring API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.scoring, create_scoring_router)
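
A self-contained sketch of the app-state dependency pattern these routes rely on: the host app stows implementations on app.state.impls and each handler resolves its service through a Depends() hook. No llama_stack imports here; the string key "scoring" and EchoScoring are stand-ins for Api.scoring and a real provider.

from fastapi import Depends, FastAPI, Request
from fastapi.testclient import TestClient

app = FastAPI()


class EchoScoring:
    """Stand-in provider: echoes rows back as its per-row scores."""

    async def score(self, input_rows):
        return {"results": {"echo": {"score_rows": input_rows, "aggregated_results": {}}}}


app.state.impls = {"scoring": EchoScoring()}  # a host app would populate this at startup


def get_scoring_service(request: Request):
    impls = getattr(request.app.state, "impls", {})
    if "scoring" not in impls:
        raise ValueError("Scoring API implementation not found")
    return impls["scoring"]


@app.post("/v1/scoring/score")
async def score(body: dict, svc=Depends(get_scoring_service)):
    return await svc.score(body.get("input_rows", []))


client = TestClient(app)
print(client.post("/v1/scoring/score", json={"input_rows": [{"a": 1}]}).json())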
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod

# mapping of metric to value
ScoringResultRow = dict[str, Any]


@json_schema_type
class ScoringResult(BaseModel):
    """
    A scoring result for a single row.

    :param score_rows: The scoring result for each row. Each row is a map of column name to value.
    :param aggregated_results: Map of metric name to aggregated value
    """

    score_rows: list[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: dict[str, Any]


@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Response from batch scoring operations on datasets.

    :param dataset_id: (Optional) The identifier of the dataset that was scored
    :param results: A map of scoring function name to ScoringResult
    """

    dataset_id: str | None = None
    results: dict[str, ScoringResult]


@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.

    :param results: A map of scoring function name to ScoringResult.
    """

    # each key in the dict is a scoring function name
    results: dict[str, ScoringResult]


class ScoringFunctionStore(Protocol):
    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...


@runtime_checkable
class Scoring(Protocol):
    scoring_function_store: ScoringFunctionStore

    @webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a batch of rows.

        :param dataset_id: The ID of the dataset to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :param save_results_dataset: Whether to save the results to a dataset.
        :returns: A ScoreBatchResponse.
        """
        ...

    @webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None],
    ) -> ScoreResponse:
        """Score a list of rows.

        :param input_rows: The rows to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :returns: A ScoreResponse object containing rows and aggregated results.
        """
        ...
37 src/llama_stack/apis/scoring/scoring_service.py Normal file
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams

from .models import ScoreBatchResponse, ScoreResponse


class ScoringFunctionStore(Protocol):
    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...


@runtime_checkable
class ScoringService(Protocol):
    scoring_function_store: ScoringFunctionStore

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a batch of rows."""
        ...

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None],
    ) -> ScoreResponse:
        """Score a list of rows."""
        ...
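
A hedged sketch of a provider satisfying this protocol's score() contract: a trivial exact-match scorer. It is self-contained, so plain dicts stand in for the .models response types, and the function id and row field names are invented for illustration.

import asyncio
from typing import Any


class ExactMatchScoring:
    """Toy provider: scores a row 1.0 when generated_answer equals expected_answer."""

    async def score(self, input_rows: list[dict[str, Any]], scoring_functions: dict[str, Any]) -> dict[str, Any]:
        score_rows = [
            {"score": 1.0 if row.get("generated_answer") == row.get("expected_answer") else 0.0}
            for row in input_rows
        ]
        accuracy = sum(r["score"] for r in score_rows) / len(score_rows) if score_rows else 0.0
        # Plain dicts stand in for ScoreResponse/ScoringResult to stay import-free.
        return {fn: {"score_rows": score_rows, "aggregated_results": {"accuracy": accuracy}} for fn in scoring_functions}


rows = [
    {"generated_answer": "4", "expected_answer": "4"},
    {"generated_answer": "5", "expected_answer": "4"},
]
print(asyncio.run(ExactMatchScoring().score(rows, {"basic::equality": None})))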
@@ -4,4 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .scoring_functions import *
+# Import routes to trigger router registration
+from . import routes  # noqa: F401
+from .models import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    CommonScoringFnFields,
+    ListScoringFunctionsResponse,
+    LLMAsJudgeScoringFnParams,
+    RegexParserScoringFnParams,
+    RegisterScoringFunctionRequest,
+    ScoringFn,
+    ScoringFnInput,
+    ScoringFnParams,
+    ScoringFnParamsType,
+)
+from .scoring_functions_service import ScoringFunctionsService
+
+# Backward compatibility - export ScoringFunctions as alias for ScoringFunctionsService
+ScoringFunctions = ScoringFunctionsService
+
+__all__ = [
+    "ScoringFunctions",
+    "ScoringFunctionsService",
+    "ScoringFn",
+    "ScoringFnInput",
+    "CommonScoringFnFields",
+    "ScoringFnParams",
+    "ScoringFnParamsType",
+    "LLMAsJudgeScoringFnParams",
+    "RegexParserScoringFnParams",
+    "BasicScoringFnParams",
+    "AggregationFunctionType",
+    "ListScoringFunctionsResponse",
+    "RegisterScoringFunctionRequest",
+]
143 src/llama_stack/apis/scoring_functions/models.py Normal file
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema


@json_schema_type
class ScoringFnParamsType(StrEnum):
    """Types of scoring function parameter configurations."""

    llm_as_judge = "llm_as_judge"
    regex_parser = "regex_parser"
    basic = "basic"


@json_schema_type
class AggregationFunctionType(StrEnum):
    """Types of aggregation functions for scoring results."""

    average = "average"
    weighted_average = "weighted_average"
    median = "median"
    categorical_count = "categorical_count"
    accuracy = "accuracy"


@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    """Parameters for LLM-as-judge scoring function configuration."""

    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    judge_model: str
    prompt_template: str | None = None
    judge_score_regexes: list[str] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    """Parameters for regex parser scoring function configuration."""

    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
    parsing_regexes: list[str] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class BasicScoringFnParams(BaseModel):
    """Parameters for basic scoring function configuration."""

    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )


ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")


class CommonScoringFnFields(BaseModel):
    description: str | None = None
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    params: ScoringFnParams | None = Field(
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
        default=None,
    )


@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
    """A scoring function resource for evaluating model outputs."""

    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

    @property
    def scoring_fn_id(self) -> str:
        return self.identifier

    @property
    def provider_scoring_fn_id(self) -> str | None:
        return self.provider_resource_id


class ScoringFnInput(CommonScoringFnFields, BaseModel):
    scoring_fn_id: str
    provider_id: str | None = None
    provider_scoring_fn_id: str | None = None


class ListScoringFunctionsResponse(BaseModel):
    """Response model for listing scoring functions."""

    data: list[ScoringFn] = Field(..., description="List of scoring function resources")


@json_schema_type
class RegisterScoringFunctionRequest(BaseModel):
    """Request model for registering a scoring function."""

    scoring_fn_id: str = Field(..., description="The ID of the scoring function to register")
    description: str = Field(..., description="The description of the scoring function")
    return_type: ParamType = Field(..., description="The return type of the scoring function")
    provider_scoring_fn_id: str | None = Field(
        default=None, description="The ID of the provider scoring function to use for the scoring function"
    )
    provider_id: str | None = Field(default=None, description="The ID of the provider to use for the scoring function")
    params: ScoringFnParams | None = Field(
        default=None,
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
    )
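
A sketch of the discriminated union in action: because ScoringFnParams discriminates on "type", a plain dict deserializes to the matching params class. Local mirrors of two of the three variants keep the snippet self-contained; the judge model name and regex are illustrative.

from enum import StrEnum
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class ScoringFnParamsType(StrEnum):  # local mirror (subset)
    llm_as_judge = "llm_as_judge"
    basic = "basic"


class LLMAsJudgeScoringFnParams(BaseModel):  # local mirror (subset of fields)
    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    judge_model: str
    judge_score_regexes: list[str] = Field(default_factory=list)


class BasicScoringFnParams(BaseModel):  # local mirror
    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic


ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]

# The "type" key selects the variant during validation.
params = TypeAdapter(ScoringFnParams).validate_python(
    {"type": "llm_as_judge", "judge_model": "llama-guard", "judge_score_regexes": [r"Answer: (yes|no)"]}
)
print(type(params).__name__, params.judge_model)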
111 src/llama_stack/apis/scoring_functions/routes.py Normal file
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router

from .models import (
    ListScoringFunctionsResponse,
    RegisterScoringFunctionRequest,
    ScoringFn,
)
from .scoring_functions_service import ScoringFunctionsService


def get_scoring_functions_service(request: Request) -> ScoringFunctionsService:
    """Dependency to get the scoring functions service implementation from app state."""
    impls = getattr(request.app.state, "impls", {})
    if Api.scoring_functions not in impls:
        raise ValueError("Scoring Functions API implementation not found")
    return impls[Api.scoring_functions]


router = APIRouter(
    prefix=f"/{LLAMA_STACK_API_V1}",
    tags=["Scoring Functions"],
    responses=standard_responses,
)


@router.get(
    "/scoring-functions",
    response_model=ListScoringFunctionsResponse,
    summary="List all scoring functions",
    description="List all scoring functions",
)
async def list_scoring_functions(
    svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> ListScoringFunctionsResponse:
    """List all scoring functions."""
    return await svc.list_scoring_functions()


@router.get(
    "/scoring-functions/{scoring_fn_id:path}",
    response_model=ScoringFn,
    summary="Get a scoring function by its ID",
    description="Get a scoring function by its ID",
)
async def get_scoring_function(
    scoring_fn_id: Annotated[str, FastAPIPath(..., description="The ID of the scoring function to get")],
    svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> ScoringFn:
    """Get a scoring function by its ID."""
    return await svc.get_scoring_function(scoring_fn_id)


@router.post(
    "/scoring-functions",
    response_model=None,
    status_code=204,
    summary="Register a scoring function",
    description="Register a scoring function",
)
async def register_scoring_function(
    body: RegisterScoringFunctionRequest = Body(...),
    svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> None:
    """Register a scoring function."""
    return await svc.register_scoring_function(
        scoring_fn_id=body.scoring_fn_id,
        description=body.description,
        return_type=body.return_type,
        provider_scoring_fn_id=body.provider_scoring_fn_id,
        provider_id=body.provider_id,
        params=body.params,
    )


@router.delete(
    "/scoring-functions/{scoring_fn_id:path}",
    response_model=None,
    status_code=204,
    summary="Unregister a scoring function",
    description="Unregister a scoring function",
)
async def unregister_scoring_function(
    scoring_fn_id: Annotated[str, FastAPIPath(..., description="The ID of the scoring function to unregister")],
    svc: ScoringFunctionsService = Depends(get_scoring_functions_service),
) -> None:
    """Unregister a scoring function."""
    await svc.unregister_scoring_function(scoring_fn_id)


# For backward compatibility with the router registry system
def create_scoring_functions_router(impl_getter) -> APIRouter:
    """Create a FastAPI router for the Scoring Functions API (legacy compatibility)."""
    return router


# Register the router factory
register_router(Api.scoring_functions, create_scoring_functions_router)
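
One design note worth calling out: the ":path" converter in the get/unregister routes lets scoring-function ids contain slashes. A minimal, self-contained demonstration of that routing behavior (the app, route, and ids below are invented for illustration):

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/v1/scoring-functions/{scoring_fn_id:path}")
async def get_scoring_function(scoring_fn_id: str) -> dict:
    # The ":path" converter matches across "/" instead of stopping at it.
    return {"identifier": scoring_fn_id}


client = TestClient(app)
print(client.get("/v1/scoring-functions/basic::equality").json())
print(client.get("/v1/scoring-functions/provider/ns/fn").json())  # slashes preserved in the id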
@@ -1,208 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# TODO: use enum.StrEnum when we drop support for python 3.10
from enum import StrEnum
from typing import (
    Annotated,
    Any,
    Literal,
    Protocol,
    runtime_checkable,
)

from pydantic import BaseModel, Field

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(StrEnum):
    """Types of scoring function parameter configurations.

    :cvar llm_as_judge: Use an LLM model to evaluate and score responses
    :cvar regex_parser: Use regex patterns to extract and score specific parts of responses
    :cvar basic: Basic scoring with simple aggregation functions
    """

    llm_as_judge = "llm_as_judge"
    regex_parser = "regex_parser"
    basic = "basic"


@json_schema_type
class AggregationFunctionType(StrEnum):
    """Types of aggregation functions for scoring results.

    :cvar average: Calculate the arithmetic mean of scores
    :cvar weighted_average: Calculate a weighted average of scores
    :cvar median: Calculate the median value of scores
    :cvar categorical_count: Count occurrences of categorical values
    :cvar accuracy: Calculate accuracy as the proportion of correct answers
    """

    average = "average"
    weighted_average = "weighted_average"
    median = "median"
    categorical_count = "categorical_count"
    accuracy = "accuracy"


@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    """Parameters for LLM-as-judge scoring function configuration.

    :param type: The type of scoring function parameters, always llm_as_judge
    :param judge_model: Identifier of the LLM model to use as a judge for scoring
    :param prompt_template: (Optional) Custom prompt template for the judge model
    :param judge_score_regexes: Regexes to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    judge_model: str
    prompt_template: str | None = None
    judge_score_regexes: list[str] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    """Parameters for regex parser scoring function configuration.

    :param type: The type of scoring function parameters, always regex_parser
    :param parsing_regexes: Regex to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
    parsing_regexes: list[str] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class BasicScoringFnParams(BaseModel):
    """Parameters for basic scoring function configuration.

    :param type: The type of scoring function parameters, always basic
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )


ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")


class CommonScoringFnFields(BaseModel):
    description: str | None = None
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    params: ScoringFnParams | None = Field(
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
        default=None,
    )


@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
    """A scoring function resource for evaluating model outputs.

    :param type: The resource type, always scoring_function
    """

    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

    @property
    def scoring_fn_id(self) -> str:
        return self.identifier

    @property
    def provider_scoring_fn_id(self) -> str | None:
        return self.provider_resource_id


class ScoringFnInput(CommonScoringFnFields, BaseModel):
    scoring_fn_id: str
    provider_id: str | None = None
    provider_scoring_fn_id: str | None = None


class ListScoringFunctionsResponse(BaseModel):
    data: list[ScoringFn]


@runtime_checkable
class ScoringFunctions(Protocol):
    @webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        """List all scoring functions.

        :returns: A ListScoringFunctionsResponse.
        """
        ...

    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
        """Get a scoring function by its ID.

        :param scoring_fn_id: The ID of the scoring function to get.
        :returns: A ScoringFn.
        """
        ...

    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        """Register a scoring function.

        :param scoring_fn_id: The ID of the scoring function to register.
        :param description: The description of the scoring function.
        :param return_type: The return type of the scoring function.
        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
        :param provider_id: The ID of the provider to use for the scoring function.
        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
        """
        ...

    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
        """Unregister a scoring function.

        :param scoring_fn_id: The ID of the scoring function to unregister.
        """
        ...
Some files were not shown because too many files have changed in this diff