Merge branch 'main' into nvidia-e2e-notebook

Jash Gulabrai 2025-05-19 09:23:07 -04:00
commit 51b68b4be6
234 changed files with 21943 additions and 7540 deletions


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
@ -12,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
@ -29,12 +31,20 @@ from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
class Attachment(BaseModel):
"""An attachment to an agent turn.
@ -73,7 +83,7 @@ class StepCommon(BaseModel):
completed_at: datetime | None = None
class StepType(Enum):
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
@ -97,7 +107,7 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@ -109,7 +119,7 @@ class ToolExecutionStep(StepCommon):
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@ -121,7 +131,7 @@ class ShieldCallStep(StepCommon):
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@ -133,7 +143,7 @@ class MemoryRetrievalStep(StepCommon):
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -154,7 +164,7 @@ class Turn(BaseModel):
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=list)
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@ -182,10 +192,10 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=list)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
@ -232,21 +242,11 @@ class Agent(BaseModel):
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(Enum):
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
@ -258,15 +258,15 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@ -276,7 +276,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
@ -285,21 +285,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
@ -341,7 +339,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None
@ -415,8 +413,9 @@ class Agents(Protocol):
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -510,6 +509,7 @@ class Agents(Protocol):
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@ -519,7 +519,7 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID.
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
@ -531,17 +531,19 @@ class Agents(Protocol):
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID.
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:returns: A ListAgentsResponse.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@ -558,11 +560,15 @@ class Agents(Protocol):
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
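
Since list_agents and list_agent_sessions now return a PaginatedResponse instead of the dedicated list responses, callers page through results explicitly. A hypothetical client-side sketch, assuming PaginatedResponse exposes data and has_more (as described in the DatasetIO docstring later in this commit) and using agents_api as a stand-in for any Agents implementation:

async def collect_all_agents(agents_api, page_size: int = 20) -> list:
    # Hypothetical paging loop over the new PaginatedResponse-based listing.
    items: list = []
    start_index = 0
    while True:
        page = await agents_api.list_agents(start_index=start_index, limit=page_size)
        items.extend(page.data)
        if not page.has_more:
            break
        start_index += page_size
    return items
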
@ -588,7 +594,7 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses", method="POST")
async def create_openai_response(
self,
input: str | list[OpenAIResponseInputMessage],
input: str | list[OpenAIResponseInput],
model: str,
previous_response_id: str | None = None,
store: bool | None = True,
@ -601,4 +607,6 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:returns: An OpenAIResponseObject.
"""
...


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
message: str
@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
text: str
type: Literal["input_text"] = "input_text"
@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
type: Literal["input_image"] = "input_image"
# TODO: handle file_id
image_url: str | None = None
# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
text: str
@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
@json_schema_type
class OpenAIResponseOutputMessage(BaseModel):
id: str
content: list[OpenAIResponseOutputMessageContent]
role: Literal["assistant"] = "assistant"
status: str
class OpenAIResponseMessage(BaseModel):
"""
Corresponds to the various Message types in the Responses API.
They are all under one type because the Responses API gives them all
the same "type" value, and there is no way to tell them apart in certain
scenarios.
"""
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message"
# The fields below are not used in all scenarios, but are required in others.
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
type: Literal["web_search_call"] = "web_search_call"
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
type: Literal["function_call"] = "function_call"
id: str
status: str
OpenAIResponseOutput = Annotated[
OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
text: str
type: Literal["input_text"] = "input_text"
class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
"""
This represents the output of a function call that gets passed back to the model.
"""
call_id: str
output: str
type: Literal["function_call_output"] = "function_call_output"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
type: Literal["input_image"] = "input_image"
# TODO: handle file_id
image_url: str | None = None
# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
Field(discriminator="type"),
OpenAIResponseInput = Annotated[
# Responses API allows output messages to be passed in as input
OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
|
# Fallback to the generic message type as a last resort
OpenAIResponseMessage,
Field(union_mode="left_to_right"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
content: str | list[OpenAIResponseInputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] | None = "message"
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
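
A short illustrative sketch of what the new OpenAIResponseInput union accepts, assuming the models above are importable from this module; the identifiers and values are hypothetical:

# A user message resolves to OpenAIResponseMessage; a prior function-call result
# resolves to OpenAIResponseInputFunctionToolCallOutput via the left_to_right union.
user_turn = OpenAIResponseMessage(
    role="user",
    content=[OpenAIResponseInputMessageContentText(text="What is the weather in Paris?")],
)
tool_result = OpenAIResponseInputFunctionToolCallOutput(
    call_id="call_123",           # hypothetical call id
    output='{"temperature_c": 18}',
)
input_items: list[OpenAIResponseInput] = [user_turn, tool_result]
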
@json_schema_type
@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
# TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
class FileSearchRankingOptions(BaseModel):
ranker: str | None = None
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
type: Literal["file_search"] = "file_search"
vector_store_id: list[str]
ranking_options: FileSearchRankingOptions | None = None
# TODO: add filters
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch,
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
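
An illustrative sketch of the expanded tool union, again assuming the models above are importable; the tool names, schema, and vector store id are hypothetical:

weather_fn = OpenAIResponseInputToolFunction(
    name="get_weather",
    description="Look up current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)
doc_search = OpenAIResponseInputToolFileSearch(
    vector_store_id=["vs_docs"],  # hypothetical vector store id
    ranking_options=FileSearchRankingOptions(score_threshold=0.5),
)
tools: list[OpenAIResponseInputTool] = [weather_fn, doc_search]
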
class OpenAIResponseInputItemList(BaseModel):
data: list[OpenAIResponseInput]
object: Literal["list"] = "list"


@ -38,7 +38,17 @@ class BatchInference(Protocol):
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
@ -52,4 +62,17 @@ class BatchInference(Protocol):
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...


@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property
def benchmark_id(self) -> str:
return self.identifier
@property
def provider_benchmark_id(self) -> str:
def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id
@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark: ...
) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None: ...
) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...


@ -28,7 +28,7 @@ class _URLOrData(BaseModel):
url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64
data: str | None = Field(contentEncoding="base64", default=None)
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
@model_validator(mode="before")
@classmethod
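
The data field change above moves the custom contentEncoding hint into json_schema_extra, since Pydantic v2 treats unknown Field() keyword arguments as deprecated. A minimal sketch of the effect, using a hypothetical stand-in model:

from pydantic import BaseModel, Field

class Base64Payload(BaseModel):
    # Hypothetical stand-in mirroring _URLOrData.data: extra JSON Schema keywords
    # are passed through json_schema_extra rather than as bare Field kwargs.
    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})

# The generated property schema for "data" carries "contentEncoding": "base64".
print(Base64Payload.model_json_schema()["properties"]["data"])
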


@ -34,14 +34,21 @@ class DatasetIO(Protocol):
- limit: Number of items to return. If None or -1, returns all items.
The response includes:
- data: List of items for the current page
- has_more: Whether there are more items available after this set
- data: List of items for the current page.
- has_more: Whether there are more items available after this set.
:param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
:returns: A PaginatedResponse.
"""
...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.
:param dataset_id: The ID of the dataset to append the rows to.
:param rows: The rows to append to the dataset.
"""
...


@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
type: Literal[ResourceType.dataset] = ResourceType.dataset
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str:
def provider_dataset_id(self) -> str | None:
return self.provider_resource_id
@ -137,7 +137,8 @@ class Datasets(Protocol):
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
:param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
@ -188,8 +189,9 @@ class Datasets(Protocol):
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
- E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
"""
...
@ -197,13 +199,29 @@ class Datasets(Protocol):
async def get_dataset(
self,
dataset_id: str,
) -> Dataset: ...
) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,
) -> None: ...
) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...


@ -93,8 +93,9 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
:returns: The job that was created to run the evaluation.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
@ -110,8 +111,9 @@ class Eval(Protocol):
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
:returns: EvaluateResponse object containing generations and scores.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
@ -119,7 +121,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
:returns: The status of the evaluation job.
"""
...
@ -138,5 +140,6 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
:returns: The result of the job.
"""
...


@ -91,10 +91,11 @@ class Files(Protocol):
"""
Create a new upload session for a file identified by a bucket and key.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param mime_type: MIME type of the file
:param size: File size in bytes
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:param mime_type: MIME type of the file.
:param size: File size in bytes.
:returns: A FileUploadResponse.
"""
...
@ -107,7 +108,8 @@ class Files(Protocol):
Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded.
:param upload_id: ID of the upload session
:param upload_id: ID of the upload session.
:returns: A FileResponse or None if the upload is not complete.
"""
...
@ -117,9 +119,10 @@ class Files(Protocol):
upload_id: str,
) -> FileUploadResponse:
"""
Returns information about an existsing upload session
Returns information about an existsing upload session.
:param upload_id: ID of the upload session
:param upload_id: ID of the upload session.
:returns: A FileUploadResponse.
"""
...
@ -130,6 +133,9 @@ class Files(Protocol):
) -> ListBucketResponse:
"""
List all buckets.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListBucketResponse.
"""
...
@ -141,7 +147,8 @@ class Files(Protocol):
"""
List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListFileResponse.
"""
...
@ -154,8 +161,9 @@ class Files(Protocol):
"""
Get a file info identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:returns: A FileResponse.
"""
...
@ -168,7 +176,7 @@ class Files(Protocol):
"""
Delete a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
"""
...


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import (
@ -35,6 +36,16 @@ register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
@json_schema_type
class GreedySamplingStrategy(BaseModel):
@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=list)
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: StopReason | None = None
class ResponseFormatType(Enum):
class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding.
:cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
"""
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
json_schema: dict[str, Any]
@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
:param bnf: The BNF grammar specification the response should conform to
"""
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
bnf: dict[str, Any]
@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=list)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: str | None = None
strict: bool | None = None
description: str | None
strict: bool | None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: dict[str, Any] | None = None
schema: dict[str, Any] | None
@json_schema_type
@ -809,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
class Order(Enum):
asc = "asc"
desc = "desc"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None
@ -834,13 +862,13 @@ class Inference(Protocol):
"""Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content: The content to generate a completion for
:param sampling_params: (Optional) Parameters to control the sampling strategy
:param response_format: (Optional) Grammar specification for guided (structured) decoding
:param content: The content to generate a completion for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: If stream=False, returns a CompletionResponse with the full completion.
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
"""
...
@ -853,6 +881,15 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
@ -872,9 +909,9 @@ class Inference(Protocol):
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model
:param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
@ -891,7 +928,7 @@ class Inference(Protocol):
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
"""
...
@ -906,6 +943,17 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
@ -924,7 +972,7 @@ class Inference(Protocol):
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
"""
...
@ -956,22 +1004,23 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for
:param best_of: (Optional) The number of completions to generate
:param echo: (Optional) Whether to echo the prompt
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param presence_penalty: (Optional) The penalty for repeated tokens
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAICompletion.
"""
...
@ -1005,27 +1054,64 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param function_call: (Optional) The function call to use
:param functions: (Optional) List of functions to use
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
:param presence_penalty: (Optional) The penalty for repeated tokens
:param response_format: (Optional) The response format to use
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param tool_choice: (Optional) The tool choice to use
:param tools: (Optional) The tools to use
:param top_logprobs: (Optional) The top log probabilities to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
"""
...
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
:returns: A ListOpenAIChatCompletionResponse.
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.
"""
raise NotImplementedError("Get chat completion is not implemented")


@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...
async def list_routes(self) -> ListRoutesResponse:
"""List all routes.
:returns: A ListRoutesResponse.
"""
...
@webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ...
async def health(self) -> HealthInfo:
"""Get the health of the service.
:returns: A HealthInfo.
"""
...
@webmethod(route="/version", method="GET")
async def version(self) -> VersionInfo: ...
async def version(self) -> VersionInfo:
"""Get the version of the service.
:returns: A VersionInfo.
"""
...


@ -29,14 +29,14 @@ class ModelType(str, Enum):
@json_schema_type
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
type: Literal[ResourceType.model] = ResourceType.model
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
def provider_model_id(self) -> str | None:
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: A OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
) -> Model: ...
) -> Model:
"""Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST")
async def register_model(
@ -99,10 +115,25 @@ class Models(Protocol):
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model: ...
) -> Model:
"""Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model(
self,
model_id: str,
) -> None: ...
) -> None:
"""Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...


@ -182,7 +182,19 @@ class PostTraining(Protocol):
),
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ...
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize(
@ -193,16 +205,49 @@ class PostTraining(Protocol):
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...


@ -32,7 +32,18 @@ class Providers(Protocol):
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...


@ -4,12 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from enum import Enum
from pydantic import BaseModel, Field
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class ResourceType(Enum):
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
class ResourceType(StrEnum):
model = "model"
shield = "shield"
vector_db = "vector_db"
@ -25,9 +36,9 @@ class Resource(BaseModel):
identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
provider_resource_id: str | None = Field(
default=None,
description="Unique identifier for this resource in the provider",
)
provider_id: str = Field(description="ID of the provider that owns this resource")


@ -53,5 +53,13 @@ class Safety(Protocol):
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse: ...
params: dict[str, Any],
) -> RunShieldResponse:
"""Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...


@ -61,7 +61,15 @@ class Scoring(Protocol):
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST")
async def score(
@ -73,6 +81,6 @@ class Scoring(Protocol):
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
:returns: A ScoreResponse object containing rows and aggregated results.
"""
...


@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum
from typing import (
Annotated,
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
class ScoringFnParamsType(StrEnum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
class AggregationFunctionType(StrEnum):
average = "average"
weighted_average = "weighted_average"
median = "median"
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str
prompt_template: str | None = None
judge_score_regexes: list[str] | None = Field(
judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] | None = Field(
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
default_factory=lambda: [],
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: list[str] | None = Field(
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] | None = Field(
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
default_factory=lambda: [],
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: list[AggregationFunctionType] | None = Field(
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id
@ -123,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
@ -137,4 +159,14 @@ class ScoringFunctions(Protocol):
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None: ...
) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
"""
...


@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
type: Literal[ResourceType.shield] = ResourceType.shield
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
def provider_shield_id(self) -> str | None:
return self.provider_resource_id
@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
:returns: A ListShieldsResponse.
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Shield: ...
async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
:param identifier: The identifier of the shield to get.
:returns: A Shield.
"""
...
@webmethod(route="/shields", method="POST")
async def register_shield(
@ -58,4 +69,13 @@ class Shields(Protocol):
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield: ...
) -> Shield:
"""Register a shield.
:param shield_id: The identifier of the shield to register.
:param provider_shield_id: The identifier of the shield in the provider.
:param provider_id: The identifier of the provider.
:param params: The parameters of the shield.
:returns: A Shield.
"""
...

View file

@ -37,7 +37,7 @@ class Span(BaseModel):
name: str
start_time: datetime
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=dict)
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
@ -74,19 +74,19 @@ class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type
class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
@ -131,14 +131,14 @@ class StructuredLogType(Enum):
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str
parent_span_id: str | None = None
@json_schema_type
class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus
@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload
@ -203,10 +203,61 @@ class QuerySpanTreeResponse(BaseModel):
data: dict[str, SpanWithStatus]
class MetricQueryType(Enum):
RANGE = "range"
INSTANT = "instant"
class MetricLabelOperator(Enum):
EQUALS = "="
NOT_EQUALS = "!="
REGEX_MATCH = "=~"
REGEX_NOT_MATCH = "!~"
class MetricLabelMatcher(BaseModel):
name: str
value: str
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
@json_schema_type
class MetricLabel(BaseModel):
name: str
value: str
@json_schema_type
class MetricDataPoint(BaseModel):
timestamp: int
value: float
@json_schema_type
class MetricSeries(BaseModel):
metric: str
labels: list[MetricLabel]
values: list[MetricDataPoint]
class QueryMetricsResponse(BaseModel):
data: list[MetricSeries]
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
async def log_event(
self,
event: Event,
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
) -> None:
"""Log an event.
:param event: The event to log.
:param ttl_seconds: The time to live of the event.
"""
...
@webmethod(route="/telemetry/traces", method="POST")
async def query_traces(
@ -215,13 +266,35 @@ class Telemetry(Protocol):
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse: ...
) -> QueryTracesResponse:
"""Query traces.
:param attribute_filters: The attribute filters to apply to the traces.
:param limit: The limit of traces to return.
:param offset: The offset of the traces to return.
:param order_by: The order by of the traces to return.
:returns: A QueryTracesResponse.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
:param trace_id: The ID of the trace to get.
:returns: A Trace.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
:param trace_id: The ID of the trace to get the span from.
:param span_id: The ID of the span to get.
:returns: A Span.
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
async def get_span_tree(
@ -229,7 +302,15 @@ class Telemetry(Protocol):
span_id: str,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse: ...
) -> QuerySpanTreeResponse:
"""Get a span tree by its ID.
:param span_id: The ID of the span to get the tree from.
:param attributes_to_return: The attributes to return in the tree.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpanTreeResponse.
"""
...
@webmethod(route="/telemetry/spans", method="POST")
async def query_spans(
@ -237,7 +318,15 @@ class Telemetry(Protocol):
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse: ...
) -> QuerySpansResponse:
"""Query spans.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_return: The attributes to return in the spans.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpansResponse.
"""
...
@webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset(
@ -246,4 +335,34 @@ class Telemetry(Protocol):
attributes_to_save: list[str],
dataset_id: str,
max_depth: int | None = None,
) -> None: ...
) -> None:
"""Save spans to a dataset.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_save: The attributes to save to the dataset.
:param dataset_id: The ID of the dataset to save the spans to.
:param max_depth: The maximum depth of the tree.
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
"""Query metrics.
:param metric_name: The name of the metric to query.
:param start_time: The start time of the metric to query.
:param end_time: The end time of the metric to query.
:param granularity: The granularity of the metric to query.
:param query_type: The type of query to perform.
:param label_matchers: The label matchers to apply to the metric.
:returns: A QueryMetricsResponse.
"""
...
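The `=`, `!=`, `=~` and `!~` operators above mirror Prometheus-style label matchers, and RANGE/INSTANT mirror range versus instant queries. A self-contained sketch of how a caller's matchers could be rendered into a selector string; the classes are redeclared locally so the snippet runs standalone, and the metric and label names are illustrative only:

from enum import Enum

from pydantic import BaseModel


class MetricLabelOperator(Enum):
    EQUALS = "="
    NOT_EQUALS = "!="
    REGEX_MATCH = "=~"
    REGEX_NOT_MATCH = "!~"


class MetricLabelMatcher(BaseModel):
    name: str
    value: str
    operator: MetricLabelOperator = MetricLabelOperator.EQUALS


matchers = [
    MetricLabelMatcher(name="model_id", value="llama-3.*", operator=MetricLabelOperator.REGEX_MATCH),
    MetricLabelMatcher(name="provider_id", value="vllm"),
]
selector = "{" + ", ".join(f'{m.name}{m.operator.value}"{m.value}"' for m in matchers) + "}"
print("prompt_tokens" + selector)  # prompt_tokens{model_id=~"llama-3.*", provider_id="vllm"}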

View file

@ -7,7 +7,7 @@
from enum import Enum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
"""
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
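A standalone sketch of how the chunk_template placeholders expand for a single retrieved chunk; this mirrors the formatting contract described in the docstring, and `SimpleChunk` is a stand-in for the real Chunk type:

from dataclasses import dataclass


@dataclass
class SimpleChunk:
    content: str


template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunk = SimpleChunk(content="Llama Stack unifies inference, safety and memory APIs.")
metadata = {"source": "docs/overview.md"}

# str.format supports attribute access, so {chunk.content} resolves directly.
print(template.format(index=1, chunk=chunk, metadata=metadata))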
@runtime_checkable

View file

@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
tool_host: ToolHost
description: str
@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None
@ -103,37 +103,65 @@ class ToolGroups(Protocol):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group"""
"""Register a tool group.
:param toolgroup_id: The ID of the tool group to register.
:param provider_id: The ID of the provider to use for the tool group.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param args: A dictionary of arguments to pass to the tool group.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
) -> ToolGroup: ...
) -> ToolGroup:
"""Get a tool group by its ID.
:param toolgroup_id: The ID of the tool group to get.
:returns: A ToolGroup.
"""
...
@webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider"""
"""List tool groups with optional provider.
:returns: A ListToolGroupsResponse.
"""
...
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group"""
"""List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolsResponse.
"""
...
@webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
) -> Tool:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A Tool.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,
) -> None:
"""Unregister a tool group"""
"""Unregister a tool group.
:param toolgroup_id: The ID of the tool group to unregister.
"""
...
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ...
) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
"""Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:returns: A ToolInvocationResult.
"""
...
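For illustration, a minimal in-memory dispatcher with the same invoke_tool call shape; it is not a real ToolRuntime provider, and the returned dict only hints at the fields a ToolInvocationResult carries:

import asyncio
from typing import Any


class DemoToolRuntime:
    def __init__(self) -> None:
        # tool_name -> callable; a real provider would resolve registered tool definitions
        self._tools = {"add": lambda a, b: a + b}

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> dict[str, Any]:
        result = self._tools[tool_name](**kwargs)
        return {"content": str(result), "error_code": 0}


print(asyncio.run(DemoToolRuntime().invoke_tool("add", {"a": 2, "b": 3})))  # {'content': '5', 'error_code': 0}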

View file

@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class VectorDB(Resource):
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str
embedding_dimension: int
@ -25,7 +25,7 @@ class VectorDB(Resource):
return self.identifier
@property
def provider_vector_db_id(self) -> str:
def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id
@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB: ...
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB: ...
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@ -46,7 +46,14 @@ class VectorIO(Protocol):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None: ...
) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param ttl_seconds: The time to live of the chunks.
"""
...
@webmethod(route="/vector-io/query", method="POST")
async def query_chunks(
@ -54,4 +61,12 @@ class VectorIO(Protocol):
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ...
) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...
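A hedged caller-side sketch tying the VectorDBs and VectorIO surfaces together; the client method names, base URL and embedding model id are assumptions drawn from common llama-stack client usage, not from this change:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.vector_dbs.register(
    vector_db_id="my-docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
client.vector_io.insert(
    vector_db_id="my-docs",
    chunks=[{"content": "Llama Stack exposes vector IO behind a common API.", "metadata": {"source": "notes"}}],
)
print(client.vector_io.query(vector_db_id="my-docs", query="What does the vector IO API do?"))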

View file

@ -38,7 +38,10 @@ class LlamaCLIParser:
print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
args = self.parser.parse_args()
if not isinstance(args, argparse.Namespace):
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
return args
def run(self, args: argparse.Namespace) -> None:
args.func(args)

View file

@ -12,6 +12,7 @@ import shutil
import sys
import textwrap
from functools import lru_cache
from importlib.abc import Traversable
from pathlib import Path
import yaml
@ -36,7 +37,8 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
@ -202,7 +204,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
else:
with open(args.config) as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
if args.image_type:
build_config.image_type = args.image_type
except Exception as e:
cprint(
f"Could not parse config file {args.config}: {e}",
@ -245,11 +251,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
sys.exit(1)
if args.run:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(config.external_providers_dir):
os.makedirs(config.external_providers_dir, exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_command(run_args)
@ -257,7 +264,7 @@ def _generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> str:
) -> Path:
"""
Generate a run.yaml template file for the user to edit from a build.yaml file
"""
@ -267,7 +274,9 @@ def _generate_run_config(
image_name=image_name,
apis=apis,
providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
# build providers dict
provider_registry = get_provider_registry(build_config)
@ -334,7 +343,7 @@ def _run_stack_build_command_from_build_config(
image_name: str | None = None,
template_name: str | None = None,
config_path: str | None = None,
) -> str:
) -> Path | Traversable:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:

View file

@ -49,7 +49,7 @@ class StackBuild(Subcommand):
type=str,
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
choices=[e.value for e in ImageType],
default=ImageType.CONDA.value,
default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
)
self.parser.add_argument(

View file

@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis]
providers = [(api, p) for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
@ -57,7 +57,7 @@ class StackListProviders(Subcommand):
rows = []
specs = [spec for p in providers for spec in p.values()]
specs = [spec for api, p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages),
",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
]
)
print_table(

View file

@ -33,7 +33,8 @@ class StackRun(Subcommand):
self.parser.add_argument(
"config",
type=str,
help="Path to config file to use for the run",
nargs="?", # Make it optional
help="Path to config file to use for the run. Required for venv and conda environments.",
)
self.parser.add_argument(
"--port",
@ -47,28 +48,12 @@ class StackRun(Subcommand):
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current environment",
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
metavar="KEY=VALUE",
)
self.parser.add_argument(
"--tls-keyfile",
type=str,
help="Path to TLS key file for HTTPS",
)
self.parser.add_argument(
"--tls-certfile",
type=str,
help="Path to TLS certificate file for HTTPS",
)
self.parser.add_argument(
"--image-type",
type=str,
@ -98,44 +83,55 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None
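A standalone sketch of the resolution order implemented above: an explicit path (or anything ending in .yaml) wins, then a bundled template's run.yaml, then a build saved under the distributions directory. The paths below are illustrative:

from pathlib import Path


def resolve_run_config(arg: str, repo_root: Path, distribs_dir: Path) -> Path:
    candidate = Path(arg)
    if candidate.exists() or arg.endswith(".yaml"):
        return candidate
    template = repo_root / "llama_stack" / "templates" / arg / "run.yaml"
    if template.exists():
        return template
    return distribs_dir / f"llamastack-{arg}" / f"{arg}-run.yaml"


print(resolve_run_config("ollama", Path("/repo"), Path.home() / ".llama" / "distributions"))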
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:
@ -157,9 +153,10 @@ class StackRun(Subcommand):
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_args.extend([str(args.port)])
if config_file:
run_args.extend(["--config", str(config_file)])
if args.env:
for env_var in args.env:
@ -172,6 +169,4 @@ class StackRun(Subcommand):
return
run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_command(run_args)

View file

@ -154,6 +154,12 @@ get_python_cmd() {
fi
}
# Add other required commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama/providers.d /.cache
EOF
if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
@ -166,17 +172,19 @@ EOF
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then
external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY $external_providers_dir /app/providers.d
COPY providers.d /.llama/providers.d
EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
fi
@ -255,9 +263,6 @@ fi
# Add other require item commands genearic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF

View file

@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
@ -73,11 +74,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
@ -91,7 +88,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:
@ -174,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**config_dict)

View file

@ -5,9 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from pathlib import Path
from typing import Annotated, Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
@ -249,10 +250,18 @@ class ServerConfig(BaseModel):
default=None,
description="Path to TLS key file for HTTPS",
)
tls_cafile: str | None = Field(
default=None,
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
)
auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
host: str | None = Field(
default=None,
description="The host the server should listen on",
)
class StackRunConfig(BaseModel):
@ -304,11 +313,20 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)
external_providers_dir: str | None = Field(
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
@ -322,8 +340,17 @@ class BuildConfig(BaseModel):
default=None,
description="Name of the distribution to build",
)
external_providers_dir: str | None = Field(
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
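A minimal standalone demonstration of the str-to-Path coercion these validators perform; `DemoConfig` is a stand-in, not the real StackRunConfig/BuildConfig:

from pathlib import Path

from pydantic import BaseModel, field_validator


class DemoConfig(BaseModel):
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        return Path(v) if isinstance(v, str) else v


cfg = DemoConfig(external_providers_dir="~/.llama/providers.d")
print(type(cfg.external_providers_dir), cfg.external_providers_dir)
# <class 'pathlib.PosixPath'> ~/.llama/providers.d  (PosixPath on Linux/macOS; '~' is expanded later via os.path.expanduser)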

View file

@ -145,7 +145,7 @@ def get_provider_registry(
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")

View file

@ -30,7 +30,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -216,7 +216,19 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
"yellow",
)
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(

View file

@ -99,7 +99,7 @@ class ProviderImpl(Providers):
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
return (
api_name,
HealthResponse(

View file

@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator

View file

@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
}
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
return {
**api_protocol_map(),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@ -302,9 +309,6 @@ async def instantiate_provider(
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@ -342,6 +346,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])

View file

@ -573,6 +573,12 @@ class InferenceRouter(Inference):
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
params = dict(
model=model_obj.identifier,
messages=messages,
@ -600,7 +606,19 @@ class InferenceRouter(Inference):
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
if stream:
return await provider.openai_chat_completion(**params)
else:
return await self._nonstream_openai_chat_completion(provider, params)
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
choice.message.tool_calls = None
return response
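A self-contained sketch of the normalization above: providers that return an empty tool_calls list are coerced to None so the response matches the OpenAI shape. DemoMessage/DemoChoice are local stand-ins, not the real response models:

from pydantic import BaseModel


class DemoMessage(BaseModel):
    content: str
    tool_calls: list[dict] | None = None


class DemoChoice(BaseModel):
    message: DemoMessage


choices = [DemoChoice(message=DemoMessage(content="Hi there", tool_calls=[]))]
for choice in choices:
    if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
        choice.message.tool_calls = None

print(choices[0].message.tool_calls)  # None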
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
@ -612,7 +630,7 @@ class InferenceRouter(Inference):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",

View file

@ -93,7 +93,7 @@ class AuthenticationMiddleware:
# Validate token and get access attributes
try:
access_attributes = await self.auth_provider.validate_token(token, scope)
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True)
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"namespaces": [token],
"roles": [token],
}
# Store attributes in request scope
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)
return await self.app(scope, receive, send)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
import json
import time
from abc import ABC, abstractmethod
from enum import Enum
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from jose import jwt
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
class TokenValidationResult(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
""",
)
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint."""
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):
KUBERNETES = "kubernetes"
CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel):
@ -82,7 +91,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token and return access attributes."""
pass
@ -92,12 +101,16 @@ class AuthProvider(ABC):
pass
class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
self._client = None
async def _get_client(self):
@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):
# Configure the client
configuration = client.Configuration()
configuration.host = self.api_server_url
if self.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path)
configuration.host = self.config.api_server_url
if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.config.ca_cert_path)
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
username = payload.get("sub", "")
groups = payload.get("groups", [])
return AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
return TokenValidationResult(
principal=username,
access_attributes=AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
),
)
except Exception as e:
@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
self._client = None
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
else:
setattr(attributes, attribute_key, values)
return attributes
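A runnable illustration of the claims-to-attributes mapping above; DemoAttributes is a local stand-in for AccessAttributes so the snippet works without a llama-stack install, and the loop body mirrors get_attributes_from_claims:

from pydantic import BaseModel


class DemoAttributes(BaseModel):
    roles: list[str] | None = None
    teams: list[str] | None = None
    projects: list[str] | None = None
    namespaces: list[str] | None = None


claims = {"sub": "user-123", "groups": ["ml-infra", "platform"]}
mapping = {"sub": "roles", "groups": "teams"}

attributes = DemoAttributes()
for claim_key, attribute_key in mapping.items():
    if claim_key not in claims or not hasattr(attributes, attribute_key):
        continue
    claim = claims[claim_key]
    values = claim if isinstance(claim, list) else claim.split()
    current = getattr(attributes, attribute_key)
    if current:
        current.extend(values)
    else:
        setattr(attributes, attribute_key, values)

print(attributes)  # roles=['user-123'] teams=['ml-infra', 'platform'] projects=None namespaces=None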
class OAuth2TokenAuthProviderConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
cache_ttl: int = 3600
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
async def close(self):
"""Close the HTTP client."""
async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
self._jwks = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
self._jwks[kid] = k
self._jwks_at = time.time()
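For reference, a self-contained sketch of the jose-based decode step used above, with a symmetric HS256 key instead of JWKS-fetched keys so it runs offline; the secret and claims are illustrative only:

from jose import jwt

secret = "demo-secret"
token = jwt.encode(
    {"sub": "user-123", "groups": ["ml-infra"], "aud": "llama-stack"},
    secret,
    algorithm="HS256",
)

claims = jwt.decode(token, secret, algorithms=["HS256"], audience="llama-stack", options={"verify_exp": True})
print(claims["sub"], claims["groups"])  # user-123 ['ml-infra']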
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"]
def __init__(self, config: CustomAuthProviderConfig):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None:
scope = {}
@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.endpoint,
self.config.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
return auth_response
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config)
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom":
return CustomAuthProvider(config.config)
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

View file

@ -9,6 +9,7 @@ import asyncio
import inspect
import json
import os
import ssl
import sys
import traceback
import warnings
@ -17,6 +18,7 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any
import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
@ -114,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, TimeoutError):
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
@ -139,7 +141,7 @@ async def shutdown(app):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@ -186,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
)
async def log_request_pre_validation(request: Request):
if request.method in ("POST", "PUT", "PATCH"):
try:
body_bytes = await request.body()
if body_bytes:
try:
parsed_body = json.loads(body_bytes.decode())
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@ -337,22 +358,11 @@ def main(args: argparse.Namespace | None = None):
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
parser.add_argument(
"--tls-keyfile",
help="Path to TLS key file for HTTPS",
required="--tls-certfile" in sys.argv,
)
parser.add_argument(
"--tls-certfile",
help="Path to TLS certificate file for HTTPS",
required="--tls-keyfile" in sys.argv,
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
@ -361,9 +371,9 @@ def main(args: argparse.Namespace | None = None):
args = parser.parse_args()
# Check for deprecated argument usage
if "--yaml-config" in sys.argv:
if "--config" in sys.argv:
warnings.warn(
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
"The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
DeprecationWarning,
stacklevel=2,
)
@ -381,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --yaml-config or --template must be provided")
raise ValueError("Either --config or --template must be provided")
logger_config = None
with open(config_file) as fp:
@ -486,21 +496,24 @@ def main(args: argparse.Namespace | None = None):
port = args.port or config.server.port
ssl_config = None
if args.tls_keyfile:
keyfile = args.tls_keyfile
certfile = args.tls_certfile
else:
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
ssl_config = {
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {

View file

@ -29,7 +29,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1
fi
@ -40,37 +40,51 @@ env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift
yaml_config="$1"
shift
port="$1"
shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Initialize env_vars as an string
# Initialize variables
yaml_config=""
env_vars=""
other_args=""
# Process environment variables from --env arguments
# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
--config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi
PYTHON_BINARY="python"
case "$env_type" in
"venv")
@ -106,8 +120,14 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--config $yaml_config"
else
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
$yaml_config_arg \
--port "$port" \
$env_vars \
$other_args
@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
# Build the command with optional yaml config
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMA_STACK_PORT=$port \
--entrypoint python \
$container_image:$version_tag \
-m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
$other_args
-m llama_stack.distribution.server.server"
# Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
else
cmd="$cmd --config /app/run.yaml"
fi
# Add any other args
cmd="$cmd $other_args"
# Execute the command
eval $cmd
fi

View file

@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:

View file

@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
if image_type == LlamaStackImageType.CONTAINER.value:
env_name = (
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
)
elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env

View file

@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
"""
)
return PromptTemplate(
template = PromptTemplate(
template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]},
).render()
)
rendered: str = template.render()
return rendered
def data_examples(self) -> list[list[ToolDefinition]]:
return [

View file

@ -173,9 +173,7 @@ INCORRECT: [get_events(location="Singapore")] <- If function not in list
- Don't repeat tool response verbatim
- Don't add supplementary information
Here is a list of functions in JSON format that you can invoke.
Here is a list of functions in JSON format that you can invoke:
[
{
"name": "get_weather",
@ -196,10 +194,7 @@ Here is a list of functions in JSON format that you can invoke.
}
}
}
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
]<|eot|><|header_start|>user<|header_end|>
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>

View file

@ -61,7 +61,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
- Don't repeat tool response verbatim
- Don't add supplementary information
{{ function_description }}
""".strip("\n")
)
@ -76,8 +75,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
Here is a list of functions in JSON format that you can invoke:
[
{% for t in tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
@ -108,10 +106,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{% endif -%}
{%- endfor %}
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
"""
)
return PromptTemplate(

View file

@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
elif model.core_model_id == CoreModelId.llama_guard_2_8b:
folder = "llama-guard-2"
else:
if model.huggingface_repo is None:
raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
folder = model.huggingface_repo.split("/")[-1]
if "Llama-2" in folder:
folder = folder.lower()
@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
return 54121549657
else:
return 100426653046
return 0

View file

@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
persistence_store: KVStore,
created_at: str,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
ShieldRunnerMixin.__init__(
self,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from llama_stack.apis.agents import (
Agent,
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -39,13 +38,14 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents):
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
)
async def create_agent_session(
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None:
pass
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
async def list_agents(self) -> ListAgentsResponse:
pass
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
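
The listing endpoints above hand their record lists to paginate_records from llama_stack.providers.utils.pagination. A minimal sketch of what such a helper does, assuming a PaginatedResponse-style result with data and has_more fields (names assumed here for illustration only):

def paginate_records_sketch(records: list[dict], start_index: int | None = None, limit: int | None = None) -> dict:
    # default to "return everything" when no paging arguments are given
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    page = records[start:end]
    return {"data": page, "has_more": end < len(records)}

# e.g. three agents, page size 2: the first page holds two records and has_more=True
print(paginate_records_sketch([{"id": 1}, {"id": 2}, {"id": 3}], start_index=0, limit=2))
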
async def get_agent(self, agent_id: str) -> Agent:
pass
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass
# OpenAI responses
@ -255,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):
async def create_openai_response(
self,
input: str | list[OpenAIResponseInputMessage],
input: str | list[OpenAIResponseInput],
model: str,
previous_response_id: str | None = None,
store: bool | None = True,

View file

@ -7,22 +7,29 @@
import json
import uuid
from collections.abc import AsyncIterator
from typing import cast
from typing import Any, cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputItemList,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseOutput,
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
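
A usage sketch of the converter above, assuming the pydantic content models are importable as in the imports at the top of this file; both typed text parts and bare strings become chat-completion text parts:

import asyncio

# hypothetical mixed content list: one typed text part plus a bare string
parts = [OpenAIResponseInputMessageContentText(text="Describe this image."), "Thanks!"]
chat_parts = asyncio.run(_convert_response_content_to_chat_content(parts))
print([p.text for p in chat_parts])  # -> ['Describe this image.', 'Thanks!']
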
async def _convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
output_messages = []
for choice in choices:
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
# TODO: handle image content
output_messages.append(
OpenAIResponseOutputMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
)
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return output_messages
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: OpenAIResponseInputItemList
response: OpenAIResponseObject
class OpenAIResponsesImpl:
@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
response_json = await self.persistence_store.get(key=key)
if response_json is None:
raise ValueError(f"OpenAI response with id '{id}' not found")
return OpenAIResponseObject.model_validate_json(response_json)
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
):
if previous_response_id:
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input_items.data
# previous response output items
new_input_items.extend(previous_response_with_input.response.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
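
Conceptually, _prepend_previous_response rebuilds the conversation as previous input items, then previous output items, then the new request input. A small sketch with plain dicts, purely to illustrate the ordering (the real code uses the pydantic message types above):

previous_input_items = [{"role": "user", "content": "What is 2+2?"}]
previous_output_items = [{"role": "assistant", "content": "4"}]
new_input = "And 3+3?"

merged = list(previous_input_items)
merged.extend(previous_output_items)
merged.append({"role": "user", "content": new_input})
print(merged)  # oldest turn first, ending with the new user message
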
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
response_with_input = await self._get_previous_response_with_input(id)
return response_with_input.response
async def create_openai_response(
self,
input: str | list[OpenAIResponseInputMessage],
input: str | list[OpenAIResponseInput],
model: str,
previous_response_id: str | None = None,
store: bool | None = True,
@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
):
stream = False if stream is None else stream
messages: list[OpenAIMessageParam] = []
if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
if isinstance(user_input.content, list):
for user_input_content in user_input.content:
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
if user_input_content.image_url:
image_url = OpenAIImageURL(
url=user_input_content.image_url, detail=user_input_content.detail
)
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
else:
user_content = input
messages.append(OpenAIUserMessageParam(content=user_content))
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_response = await self.inference_api.openai_chat_completion(
model=model,
@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
@ -163,7 +256,30 @@ class OpenAIResponsesImpl:
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
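
The aggregation loop above keys partial tool calls by their streaming index and concatenates the argument fragments. A standalone sketch of that pattern with plain dicts (field names mirror the OpenAI delta shape, shortened for illustration):

chunks = [
    {"index": 0, "id": "call_1", "name": "get_weather", "arguments": '{"city": "'},
    {"index": 0, "arguments": 'Seattle"}'},
]

aggregated: dict[int, dict] = {}
for delta in chunks:
    existing = aggregated.get(delta["index"])
    if existing:
        existing["arguments"] += delta["arguments"]
    else:
        aggregated[delta["index"]] = dict(delta)

print([aggregated[i]["arguments"] for i in sorted(aggregated)])  # -> ['{"city": "Seattle"}']
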
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
@ -181,12 +297,26 @@ class OpenAIResponsesImpl:
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
)
else:
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
else:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -195,13 +325,43 @@ class OpenAIResponsesImpl:
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
if store:
# Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
input_items = OpenAIResponseInputItemList(data=input_items_data)
prev_response = OpenAIResponsePreviousResponseWithInputItems(
input_items=input_items,
response=response,
)
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set(
key=key,
value=response.model_dump_json(),
value=prev_response.model_dump_json(),
)
if stream:
@ -221,7 +381,9 @@ class OpenAIResponsesImpl:
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "web_search":
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
@ -247,12 +409,11 @@ class OpenAIResponsesImpl:
self,
model_id: str,
stream: bool,
chat_response: OpenAIChatCompletion,
choice: OpenAIChoice,
messages: list[OpenAIMessageParam],
temperature: float,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
@ -262,6 +423,9 @@ class OpenAIResponsesImpl:
if not choice.message.tool_calls:
return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
@ -307,7 +471,9 @@ class OpenAIResponsesImpl:
)
# type cast to appease mypy
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages

View file

@ -9,9 +9,7 @@ import logging
import uuid
from datetime import datetime, timezone
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
class AgentInfo(AgentConfig):
created_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
turns=[],
)
await self.kvstore.set(
@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import LocalFSDatasetIOConfig
@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key)
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:

View file

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
def evacuate_model_from_device(model, device: str):
"""Safely clear a model from memory and free device resources.
This function handles the proper cleanup of a model by:
1. Moving the model to CPU if it's on a non-CPU device
2. Deleting the model object to free memory
3. Running garbage collection
4. Clearing CUDA cache if the model was on a CUDA device
Args:
model: The PyTorch model to clear
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
Note:
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
- For MPS devices, only moves the model to CPU (no cache clearing available)
- For CPU devices, only deletes the model object and runs garbage collection
"""
if device != "cpu":
model.to("cpu")
del model
gc.collect()
if device == "cuda":
# we need to import such that this is only imported when the method is called
import torch
torch.cuda.empty_cache()
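
A usage sketch, assuming torch is installed and the helper above is in scope; note that the caller should drop its own reference as well, since the del inside the helper only removes the function's local name:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 2).to(device)

evacuate_model_from_device(model, device)
model = None  # drop the caller's reference so the weights can actually be freed
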

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import HuggingFacePostTrainingConfig
# The post_training API and the HuggingFace provider are still experimental and under heavy development
async def get_provider_impl(
config: HuggingFacePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import HuggingFacePostTrainingImpl
impl = HuggingFacePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
# Device to run training on (cuda, cpu, mps)
device: str = "cuda"
# Distributed training backend if using multiple devices
# fsdp: Fully Sharded Data Parallel
# deepspeed: DeepSpeed ZeRO optimization
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
# Format for saving model checkpoints
# full_state: Save complete model state
# huggingface: Save in HuggingFace format (recommended for compatibility)
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
# Template for formatting chat inputs and outputs
# Used to structure the conversation format for training
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
# Model-specific configuration parameters
# trust_remote_code: Allow execution of custom model code
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
model_specific_config: dict = {
"trust_remote_code": True,
"attn_implementation": "sdpa",
}
# Maximum sequence length for training
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
# Longer sequences may cause memory issues on MPS devices
max_seq_length: int = 2048
# Enable gradient checkpointing to reduce memory usage
# Trades computation for memory by recomputing activations
gradient_checkpointing: bool = False
# Maximum number of checkpoints to keep
# Older checkpoints are deleted when this limit is reached
save_total_limit: int = 3
# Number of training steps between logging updates
logging_steps: int = 10
# Ratio of training steps used for learning rate warmup
# Helps stabilize early training
warmup_ratio: float = 0.1
# L2 regularization coefficient
# Helps prevent overfitting
weight_decay: float = 0.01
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
dataloader_num_workers: int = 4
# Whether to pin memory in data loader
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}

View file

@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class HuggingFacePostTrainingImpl:
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -0,0 +1,683 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import psutil
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
TrainingConfig,
)
from ..config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
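
A combined usage sketch of the two helpers above, assuming torch and psutil are available; it falls back to CPU when no accelerator is present:

import torch

device = setup_torch_device("cuda" if torch.cuda.is_available() else "cpu")
stats = get_memory_stats(device)
print(f"system memory available: {stats['system_memory']['available']} GB")
print(f"device memory stats: {stats.get('device_memory')}")
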
class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
"""Validate that the dataset has the required fields."""
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
return all(field in row for row in rows for field in required_fields)
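
For reference, a row that passes this validation and feeds the instruct-format path below might look like the following (values are purely illustrative):

example_row = {
    "input_query": "What is the capital of France?",
    "expected_answer": "Paris.",
    "chat_completion_input": '[{"role": "user", "content": "What is the capital of France?"}]',
}
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
print(all(field in example_row for field in required_fields))  # -> True
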
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in instruct format."""
if "chat_completion_input" in row and "expected_answer" in row:
try:
messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None
if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}")
return None, None
return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None
return None, None
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in dialog format."""
if "dialog" in row:
try:
dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None
if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}")
return None, None
if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message")
return None, None
# Convert to human/gpt format
role_map = {"user": "human", "assistant": "gpt"}
conversations = []
for msg in dialog:
if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}")
continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation
return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None
return None, None
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row using fallback formats."""
if "input" in row and "output" in row:
return row["input"], row["output"]
elif "prompt" in row and "completion" in row:
return row["prompt"], row["completion"]
elif "question" in row and "answer" in row:
return row["question"], row["answer"]
return None, None
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format input and output text based on model requirements."""
if hasattr(provider_config, "chat_template"):
return provider_config.chat_template.format(input=input_text, output=output_text)
return f"{input_text}\n{output_text}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset."""
formatted_rows = []
for row in rows:
input_text = None
output_text = None
# Process based on format
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
if config.data_config.data_format.value == "instruct":
input_text, output_text = self._process_instruct_format(row)
elif config.data_config.data_format.value == "dialog":
input_text, output_text = self._process_dialog_format(row)
else:
input_text, output_text = self._process_fallback_format(row)
if input_text and output_text:
formatted_text = self._format_text(input_text, output_text, provider_config)
formatted_rows.append({"text": formatted_text})
if not formatted_rows:
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
raise ValueError(
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
)
return Dataset.from_list(formatted_rows)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer."""
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding=True,
truncation=True,
max_length=provider_config.max_seq_length,
return_tensors=None,
)
return ds.map(
tokenize_function,
batched=True,
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running training process.
This method serves as a bridge between the multiprocessing Process and the async training function.
It creates a new event loop to run the async training process.
Args:
model: The model identifier to load
provider_config: Configuration specific to the HuggingFace provider
peft_config: Optional LoRA configuration
config: General training configuration
output_dir_path: Optional path to save the model
"""
import asyncio
logger.info("Starting training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
peft_config=peft_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for training.
Args:
model: The model identifier to load
config: Training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (train_dataset, eval_dataset, tokenizer)
"""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
# This is common for models that don't have a dedicated pad token
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to right for causal language modeling
# This ensures that padding tokens don't interfere with the model's ability
# to predict the next token in the sequence
tokenizer.padding_side = "right"
# Set truncation side to right to keep the beginning of the sequence
# This is important for maintaining context and instruction format
tokenizer.truncation_side = "right"
# Set model max length to match provider config
# This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> SFTConfig:
"""Setup training arguments.
Args:
config: Training configuration
provider_config: Provider-specific configuration
device: The device to train on
output_dir_path: Optional path to save the model
steps_per_epoch: Number of steps per epoch
Returns:
Configured SFTConfig object
"""
logger.info("Configuring training arguments")
lr = 2e-5
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Eval steps: {eval_steps}")
logger.info(f"- Save steps: {save_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
if output_dir_path:
save_strategy = "steps"
logger.info(f"Will save checkpoints to {output_dir_path}")
return SFTConfig(
max_steps=max_steps,
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy="steps",
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_seq_length=provider_config.max_seq_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
dataset_text_field="text",
packing=False,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
eval_steps=eval_steps,
save_steps=save_steps,
logging_steps=logging_steps,
)
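
As a worked example of the step math above: with 1,000 rows in the train split and batch_size=8, steps_per_epoch is 125, so evaluation runs every 12 steps, checkpoints save every 25 steps, and logging happens every 2 steps (straight from the max(1, ...) expressions):

steps_per_epoch = 1000 // 8                    # 125
eval_steps = max(1, steps_per_epoch // 10)     # 12
save_steps = max(1, steps_per_epoch // 5)      # 25
logging_steps = max(1, steps_per_epoch // 50)  # 2
print(steps_per_epoch, eval_steps, save_steps, logging_steps)
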
def save_model(
self,
model_obj: AutoModelForCausalLM,
trainer: SFTTrainer,
peft_config: LoraConfig | None,
output_dir_path: Path,
) -> None:
"""Save the trained model.
Args:
model_obj: The model to save
trainer: The trainer instance
peft_config: Optional LoRA configuration
output_dir_path: Path to save the model
"""
logger.info("Saving final model")
model_obj.config.use_cache = True
if peft_config:
logger.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload()
else:
model_obj = trainer.model
save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path)
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
trainer = SFTTrainer(
model=model_obj,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
args=training_args,
)
try:
# Train
logger.info("Starting training")
trainer.train()
logger.info("Training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
self.save_model(model_obj, trainer, peft_config, output_dir_path)
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
del trainer
gc.collect()
logger.info("Cleanup completed")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
lora_config: LoraFinetuningConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Configure LoRA
peft_config = None
if lora_config:
peft_config = LoraConfig(
lora_alpha=lora_config.alpha,
lora_dropout=0.1,
r=lora_config.rank,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_config.lora_attn_modules,
)
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"peft_config": peft_config,
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = None
if output_dir_path:
# Create checkpoint
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(output_dir_path / "merged_model"),
)
checkpoints = [checkpoint]
return memory_stats, checkpoints
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()
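The process target referenced above, _run_training_sync, is not part of this hunk. A plausible sketch of such a wrapper (an assumption, not the actual implementation) is a thin synchronous shim that drives the async _run_training inside the spawned process:

import asyncio

def _run_training_sync(self, **kwargs) -> None:
    # hypothetical: multiprocessing targets must be synchronous, so run the
    # async training coroutine to completion in the child process
    asyncio.run(self._run_training(**kwargs))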

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import logging
import os
import time
@ -39,7 +38,6 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
@ -48,6 +46,7 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
@ -90,8 +89,6 @@ class LoraFinetuningSingleDevice:
) -> None:
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
@ -557,11 +554,7 @@ class LoraFinetuningSingleDevice:
checkpoints.append(checkpoint)
# clean up the memory after training finishes
if self._device.type != "cpu":
self._model.to("cpu")
torch.cuda.empty_cache()
del self._model
gc.collect()
evacuate_model_from_device(self._model, self._device.type)
return (memory_stats, checkpoints)
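evacuate_model_from_device itself is not shown in this diff; judging from the four inline lines it replaces, and assuming torch and gc are imported as in this module, a plausible sketch (not the real helper) is:

def evacuate_model_from_device(model, device_type: str) -> None:
    # hypothetical reconstruction: move weights off the accelerator,
    # release cached CUDA blocks, then collect garbage
    if device_type != "cpu":
        model.to("cpu")
    if device_type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()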

View file

@ -20,7 +20,10 @@ from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
MetricLabelMatcher,
MetricQueryType,
QueryCondition,
QueryMetricsResponse,
QuerySpanTreeResponse,
QueryTracesResponse,
Span,
@ -123,6 +126,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
else:
raise ValueError(f"Unknown event type: {event}")
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage

View file

@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
doc.metadata,
)
)
@ -105,7 +106,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
@ -140,19 +143,21 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
metadata = c.metadata
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata["metadata_token_count"]
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(
TextContentItem(
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
)
)
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(

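For orientation, a hypothetical chunk_template and how the .format(...) call above renders it (the real default template lives in RAGQueryConfig and may differ):

template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"

class _Chunk:  # stand-in for a stored chunk
    content = "Paris is the capital of France."

print(template.format(index=1, chunk=_Chunk(), metadata={"document_id": "doc-1"}))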
View file

@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.range(start_key, end_key)
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)

View file

@ -280,11 +280,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=[
"openai",
],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
),
),
remote_provider_spec(

View file

@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]:
Api.datasets,
],
),
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
remote_provider_spec(
api=Api.post_training,
adapter=AdapterSpec(

View file

@ -12,8 +12,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import HuggingfaceDatasetIOConfig
@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -61,6 +61,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -81,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
@ -139,6 +140,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -202,6 +205,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@ -346,6 +351,8 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
@ -389,29 +396,25 @@ class OllamaInferenceAdapter(
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"prompt": prompt,
"best_of": best_of,
"echo": echo,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_tokens": max_tokens,
"n": n,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
@ -441,41 +444,31 @@ class OllamaInferenceAdapter(
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# ollama still makes tool calls even when tool_choice is "none"
# so we need to remove the tools in that case
if tool_choice == "none" and tools is not None:
tools = None
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"messages": messages,
"frequency_penalty": frequency_penalty,
"function_call": function_call,
"functions": functions,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_logprobs": top_logprobs,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self.openai_client.chat.completions.create(**params) # type: ignore
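prepare_openai_completion_params is imported from openai_compat in the hunk above; judging from the dict comprehensions it replaces here, its essential behaviour is to drop arguments that were left as None. A hypothetical minimal version, for orientation only:

async def prepare_openai_completion_params(**params):
    # hypothetical sketch: forward only the arguments the caller actually set
    return {k: v for k, v in params.items() if v is not None}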
async def batch_completion(

View file

@ -4,27 +4,60 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"gpt-4o-audio-preview",
"chatgpt-4o-latest",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
]
@dataclass
class EmbeddingModelInfo:
"""Structured representation of embedding model information."""
embedding_dimension: int
context_length: int
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-small",
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1536, "context_length": 8192},
),
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-large",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 3072, "context_length": 8192},
),
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
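For reference, the comprehension above expands to entries equivalent to the following (shown for text-embedding-3-small; only fields visible in this hunk are assumed):

ProviderModelEntry(
    provider_model_id="text-embedding-3-small",
    model_type=ModelType.embedding,
    metadata={"embedding_dimension": 1536, "context_length": 8192},
)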

View file

@ -4,12 +4,41 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
@ -19,9 +48,120 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
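A minimal illustration of the name translation the is_openai_compat comment describes. The prefixing itself happens inside LiteLLMOpenAIMixin; the helper below is a hypothetical stand-in, not the mixin's actual code:

def _to_litellm_model_name(model: str) -> str:
    # hypothetical: canonical OpenAI names get the litellm "openai/" prefix,
    # already-prefixed names pass through unchanged
    return model if model.startswith("openai/") else f"openai/{model}"

assert _to_litellm_model_name("gpt-4o") == "openai/gpt-4o"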
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._openai_client.completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._openai_client.chat.completions.create(**params)

View file

@ -4,16 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.inference import Inference
from .config import SambaNovaImplConfig
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"

View file

@ -6,25 +6,32 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaImplConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: str | None = Field(
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova.ai API Key",
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": "${env.SAMBANOVA_API_KEY}",
"api_key": api_key,
}

View file

@ -11,43 +11,43 @@ from llama_stack.providers.utils.inference.model_registry import (
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Meta-Llama-3.1-8B-Instruct",
"sambanova/Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-70B-Instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-405B-Instruct",
"sambanova/Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-1B-Instruct",
"sambanova/Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-3B-Instruct",
"sambanova/Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.3-70B-Instruct",
"sambanova/Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-11B-Vision-Instruct",
"sambanova/Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-90B-Vision-Instruct",
"sambanova/Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct",
"sambanova/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]

View file

@ -5,305 +5,249 @@
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import Iterable
from openai import OpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
JsonSchemaResponseFormat,
Message,
ResponseFormat,
SamplingParams,
StopReason,
SystemMessage,
TextTruncation,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
convert_tooldef_to_openai_tool,
get_sampling_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class SambaNovaInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:
return
async def convert_message_to_openai_dict_with_b64_images(
message: Message | dict,
) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "tool":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
async def shutdown(self) -> None:
pass
def _get_client(self) -> OpenAI:
return OpenAI(base_url=self.config.url, api_key=self.config.api_key)
async def completion(
self,
model_id: str,
# Map Llama Stack spec to OpenAI spec -
# str -> str
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
# List[...] -> List[...]
async def _convert_message_content(
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = ToolPromptFormat.json,
stream: bool | None = False,
tool_config: ToolConfig | None = None,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
request_sambanova = await self.convert_chat_completion_request(request)
if stream:
return self._stream_chat_completion(request_sambanova)
else:
return await self._nonstream_chat_completion(request_sambanova)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
response = self._get_client().chat.completions.create(**request)
choice = response.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason),
tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls),
),
logprobs=None,
)
return result
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _to_async_generator():
streaming = self._get_client().chat.completions.create(**request)
for chunk in streaming:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict:
compatible_request = self.convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages)
compatible_request["stream"] = request.stream
compatible_request["logprobs"] = False
compatible_request["extra_headers"] = {
b"User-Agent": b"llama-stack: sambanova-inference-adapter",
}
compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
return compatible_request
def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict:
params = {}
if sampling_params:
params["frequency_penalty"] = sampling_params.repetition_penalty
if sampling_params.max_tokens:
if legacy:
params["max_tokens"] = sampling_params.max_tokens
else:
params["max_completion_tokens"] = sampling_params.max_tokens
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
params["top_p"] = sampling_params.strategy.top_p
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
params["extra_body"]["top_k"] = sampling_params.strategy.top_k
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
params["temperature"] = 0.0
return params
async def convert_to_sambanova_messages(self, messages: list[Message]) -> list[dict]:
conversation = []
for message in messages:
content = {}
content["content"] = await self.convert_to_sambanova_content(message)
if isinstance(message, UserMessage):
content["role"] = "user"
elif isinstance(message, CompletionMessage):
content["role"] = "assistant"
tools = []
for tool_call in message.tool_calls:
tools.append(
{
"id": tool_call.call_id,
"function": {
"name": tool_call.name,
"arguments": json.dumps(tool_call.arguments),
},
"type": "function",
}
)
content["tool_calls"] = tools
elif isinstance(message, ToolResponseMessage):
content["role"] = "tool"
content["tool_call_id"] = message.call_id
elif isinstance(message, SystemMessage):
content["role"] = "system"
conversation.append(content)
return conversation
async def convert_to_sambanova_content(self, message: Message) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
url = await convert_image_content_to_url(content, download=True)
# A fix to make sure the call succeeds.
components = url.split(";base64")
url = f"{components[0].lower()};base64{components[1]}"
return {
"type": "image_url",
"image_url": {"url": url},
}
) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
async def impl(
content_: InterleavedContent,
) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
elif isinstance(content_, TextContentItem):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content_.text,
)
elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
raise ValueError(f"Unsupported content type: {type(content_)}")
if isinstance(message.content, list):
# If it is a list, the text content should be wrapped in dict
content = [await _convert_content(c) for c in message.content]
ret = await impl(content)
# OpenAI*Message expects a str or list
if isinstance(ret, str) or isinstance(ret, list):
return ret
else:
content = message.content
return [ret]
return content
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=await _convert_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
]
or None,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=await _convert_message_content(message.content),
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=await _convert_message_content(message.content),
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
def convert_to_sambanova_tool(self, tools: list[ToolDefinition]) -> list[dict]:
if tools is None:
return tools
return out
compatiable_tools = []
for tool in tools:
properties = {}
compatiable_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compatiable_required.append(tool_key)
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaImplConfig
compatiable_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compatiable_required,
},
def __init__(self, config: SambaNovaImplConfig):
self.config = config
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=self.config.api_key,
provider_data_api_key_field="sambanova_api_key",
)
def _get_api_key(self) -> str:
config_api_key = self.config.api_key if self.config.api_key else None
if config_api_key:
return config_api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.sambanova_api_key:
raise ValueError(
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
)
return provider_data.sambanova_api_key
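A hypothetical client-side illustration of the per-request key passing that _get_api_key supports (the header name and payload shape come from the error message above; the value is a placeholder):

import json

headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"sambanova_api_key": "<your api key>"}),
}
# attached to a request against the Llama Stack server, this key takes
# precedence over the api_key configured in SambaNovaImplConfig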
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
compatiable_tools.append(compatiable_tool)
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self._get_api_key()
if len(compatiable_tools) > 0:
return compatiable_tools
return None
def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
"model": request.model,
"api_key": api_key,
"api_base": self.config.url,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
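As a concrete illustration of the response_format translation above (hypothetical schema; only the transformation visible in this hunk is assumed), a JsonSchemaResponseFormat whose json_schema is {"title": "City", "type": "object", "properties": {"name": {"type": "string"}}} would be forwarded roughly as:

{
    "type": "json_schema",
    "json_schema": {
        "name": "City",                      # taken from the schema's "title"
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "additionalProperties": False,   # applied (recursively) before sending
        },
        "strict": True,
    },
}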
def convert_to_sambanova_tool_calls(
self,
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
async def initialize(self):
await super().initialize()
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
return compitable_tool_calls
async def shutdown(self):
await super().shutdown()

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter

View file

@ -158,27 +158,28 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
continue
choice = chunk.choices[0]
if choice.finish_reason:
args_str = tool_call_buf.arguments
args = None
try:
args = {} if not args_str else json.loads(args_str)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
if args:
yield ChatCompletionResponseStreamChunk(
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
@ -190,8 +191,12 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
elif args_str:
yield ChatCompletionResponseStreamChunk(
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
@ -200,21 +205,62 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
else:
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
@ -224,6 +270,17 @@ async def _process_vllm_chat_completion_stream_response(
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
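To make the per-index buffering above concrete, a toy illustration (made-up fragments, not vLLM's actual delta objects) of how argument pieces that share an index accumulate into a single buffer before being parsed at end of stream:

import json

# hypothetical deltas for one tool call streamed across three chunks
fragments = [(0, '{"city": '), (0, '"Paris"'), (0, "}")]
bufs: dict[int, str] = {}
for index, piece in fragments:
    bufs[index] = bufs.get(index, "") + piece
assert json.loads(bufs[0]) == {"city": "Paris"}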
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
@ -272,6 +329,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -302,6 +361,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice

View file

@ -26,8 +26,7 @@ from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
# this is a helper to allow us to use async and non-async chroma clients interchangeably

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,55 +0,0 @@
inference:
tests:
- inference/test_vision_inference.py::test_vision_chat_completion_streaming
- inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
- inference/test_text_inference.py::test_structured_output
- inference/test_text_inference.py::test_chat_completion_streaming
- inference/test_text_inference.py::test_chat_completion_non_streaming
- inference/test_text_inference.py::test_chat_completion_with_tool_calling
- inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
scenarios:
- provider_fixtures:
inference: ollama
- fixture_combo_id: fireworks
- provider_fixtures:
inference: together
# - inference: tgi
# - inference: vllm_remote
inference_models:
- meta-llama/Llama-3.1-8B-Instruct
- meta-llama/Llama-3.2-11B-Vision-Instruct
agents:
tests:
- agents/test_agents.py::test_agent_turns_with_safety
- agents/test_agents.py::test_rag_agent
scenarios:
- fixture_combo_id: ollama
- fixture_combo_id: together
- fixture_combo_id: fireworks
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
safety_shield: meta-llama/Llama-Guard-3-1B
memory:
tests:
- memory/test_memory.py::test_query_documents
scenarios:
- fixture_combo_id: ollama
- provider_fixtures:
inference: sentence_transformers
memory: faiss
- fixture_combo_id: chroma
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
embedding_model: all-MiniLM-L6-v2

View file

@ -1,296 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections import defaultdict
from pathlib import Path
from typing import Any
import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from termcolor import colored
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail
from .report import Report
class ProviderFixture(BaseModel):
providers: list[Provider]
provider_data: dict[str, Any] | None = None
class TestScenario(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: dict[str, str] = Field(default_factory=dict)
fixture_combo_id: str | None = None
class APITestConfig(BaseModel):
scenarios: list[TestScenario] = Field(default_factory=list)
inference_models: list[str] = Field(default_factory=list)
# test name format should be <relative_path.py>::<test_name>
tests: list[str] = Field(default_factory=list)
class MemoryApiTestConfig(APITestConfig):
embedding_model: str | None = Field(default_factory=None)
class AgentsApiTestConfig(APITestConfig):
safety_shield: str | None = Field(default_factory=None)
class TestConfig(BaseModel):
inference: APITestConfig | None = None
agents: AgentsApiTestConfig | None = None
memory: MemoryApiTestConfig | None = None
def get_test_config_from_config_file(metafunc_config):
config_file = metafunc_config.getoption("--config")
if config_file is None:
return None
config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path) as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)
def get_test_config_for_api(metafunc_config, api):
test_config = get_test_config_from_config_file(metafunc_config)
if test_config is None:
return None
return getattr(test_config, api)
def get_provider_fixture_overrides_from_test_config(metafunc_config, api, default_provider_fixture_combinations):
api_config = get_test_config_for_api(metafunc_config, api)
if api_config is None:
return None
fixture_combo_ids = set()
custom_provider_fixture_combos = []
for scenario in api_config.scenarios:
if scenario.fixture_combo_id:
fixture_combo_ids.add(scenario.fixture_combo_id)
else:
custom_provider_fixture_combos.append(
pytest.param(
scenario.provider_fixtures,
id=scenario.provider_fixtures.get("inference") or "",
)
)
if len(fixture_combo_ids) > 0:
for default_fixture in default_provider_fixture_combinations:
if default_fixture.id in fixture_combo_ids:
custom_provider_fixture_combos.append(default_fixture)
return custom_provider_fixture_combos
def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture(
providers=[
Provider(
provider_id="test::remote",
provider_type="test::remote",
config=config.model_dump(),
)
],
)
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True
"""Load environment variables at start of test run"""
# Load from .env file if it exists
env_file = Path(__file__).parent / ".env"
if env_file.exists():
load_dotenv(env_file)
# Load any environment variables passed via --env
env_vars = config.getoption("--env") or []
for env_var in env_vars:
key, value = env_var.split("=", 1)
os.environ[key] = value
if config.getoption("--output") is not None:
config.pluginmanager.register(Report(config.getoption("--output")))
def pytest_addoption(parser):
parser.addoption(
"--providers",
default="",
help=(
"Provider configuration in format: api1=provider1,api2=provider2. "
"Example: --providers inference=ollama,safety=meta-reference"
),
)
parser.addoption(
"--config",
action="store",
help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
)
parser.addoption(
"--output",
action="store",
help="Set output file for test report, e.g. --output=pytest_report.md",
)
"""Add custom command line options"""
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
parser.addoption(
"--inference-model",
action="store",
default="meta-llama/Llama-3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-shield",
action="store",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
parser.addoption(
"--judge-model",
action="store",
default="meta-llama/Llama-3.1-8B-Instruct",
help="Specify the judge model to use for testing",
)
def make_provider_id(providers: dict[str, str]) -> str:
return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))
def get_provider_marks(providers: dict[str, str]) -> list[Any]:
marks = []
for provider in providers.values():
marks.append(getattr(pytest.mark, provider))
return marks
def get_provider_fixture_overrides(config, available_fixtures: dict[str, list[str]]) -> list[pytest.param] | None:
provider_str = config.getoption("--providers")
if not provider_str:
return None
fixture_dict = parse_fixture_string(provider_str, available_fixtures)
return [
pytest.param(
fixture_dict,
id=make_provider_id(fixture_dict),
marks=get_provider_marks(fixture_dict),
)
]
def parse_fixture_string(provider_str: str, available_fixtures: dict[str, list[str]]) -> dict[str, str]:
"""Parse provider string of format 'api1=provider1,api2=provider2'"""
if not provider_str:
return {}
fixtures = {}
pairs = provider_str.split(",")
for pair in pairs:
if "=" not in pair:
raise ValueError(f"Invalid provider specification: {pair}. Expected format: api=provider")
api, fixture = pair.split("=")
if api not in available_fixtures:
raise ValueError(f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}")
if fixture not in available_fixtures[api]:
raise ValueError(
f"Unknown provider '{fixture}' for API '{api}'. Available providers: {list(available_fixtures[api])}"
)
fixtures[api] = fixture
# Check that all provided APIs are supported
for api in available_fixtures.keys():
if api not in fixtures:
raise ValueError(
f"Missing provider fixture for API '{api}'. Available providers: {list(available_fixtures[api])}"
)
return fixtures
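# Example (illustrative): with available_fixtures
# {"inference": ["ollama", "fireworks"], "safety": ["llama_guard"]},
# parse_fixture_string("inference=ollama,safety=llama_guard", available_fixtures)
# returns {"inference": "ollama", "safety": "llama_guard"}; an unknown API, an
# unknown provider, or a missing API raises ValueError.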
def pytest_itemcollected(item):
# Get all markers as a list
filtered = ("asyncio", "parametrize")
marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
if marks:
marks = colored(",".join(marks), "yellow")
item.name = f"{item.name}[{marks}]"
def pytest_collection_modifyitems(session, config, items):
test_config = get_test_config_from_config_file(config)
if test_config is None:
return
required_tests = defaultdict(set)
for api_test_config in [
test_config.inference,
test_config.memory,
test_config.agents,
]:
if api_test_config is None:
continue
for test in api_test_config.tests:
arr = test.split("::")
if len(arr) != 2:
raise ValueError(f"Invalid format for test name {test}")
test_path, func_name = arr
required_tests[Path(__file__).parent / test_path].add(func_name)
new_items, deselected_items = [], []
for item in items:
func_name = getattr(item, "originalname", item.name)
if func_name in required_tests[item.fspath]:
new_items.append(item)
continue
deselected_items.append(item)
items[:] = new_items
config.hook.pytest_deselected(items=deselected_items)
pytest_plugins = [
"llama_stack.providers.tests.inference.fixtures",
"llama_stack.providers.tests.safety.fixtures",
"llama_stack.providers.tests.vector_io.fixtures",
"llama_stack.providers.tests.agents.fixtures",
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
"llama_stack.providers.tests.tools.fixtures",
]

View file

@ -1,176 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections import defaultdict
from pathlib import Path
import pytest
from pytest import ExitCode
from pytest_html.basereport import _process_outcome
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import CoreModelId
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = {
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"fireworks": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
"together": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
}
class Report:
def __init__(self, output_path):
valid_file_format = (
output_path.split(".")[1] in ["md", "markdown"] if len(output_path.split(".")) == 2 else False
)
if not valid_file_format:
raise ValueError(f"Invalid output file {output_path}. Markdown file is required")
self.output_path = output_path
self.test_data = defaultdict(dict)
self.inference_tests = defaultdict(dict)
@pytest.hookimpl
def pytest_runtest_logreport(self, report):
# This hook is called in several phases, including setup, call and teardown
# The test is considered failed / error if any of the outcomes is not "Passed"
outcome = _process_outcome(report)
data = {
"outcome": report.outcome,
"longrepr": report.longrepr,
"name": report.nodeid,
}
if report.nodeid not in self.test_data:
self.test_data[report.nodeid] = data
elif self.test_data[report.nodeid] != outcome and outcome != "Passed":
self.test_data[report.nodeid] = data
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus):
if exitstatus <= ExitCode.INTERRUPTED:
return
report = []
report.append("# Llama Stack Integration Test Results Report")
report.append("\n## Summary")
report.append("\n## Supported Models: ")
header = "| Model Descriptor |"
dividor = "|:---|"
for k in SUPPORTED_MODELS.keys():
header += f"{k} |"
dividor += ":---:|"
report.append(header)
report.append(dividor)
rows = []
for model in all_registered_models():
if "Instruct" not in model.core_model_id.value and "Guard" not in model.core_model_id.value:
continue
row = f"| {model.core_model_id.value} |"
for k in SUPPORTED_MODELS.keys():
if model.core_model_id.value in SUPPORTED_MODELS[k]:
row += " ✅ |"
else:
row += " ❌ |"
rows.append(row)
report.extend(rows)
report.append("\n### Tests:")
for provider in SUPPORTED_MODELS.keys():
if provider not in self.inference_tests:
continue
report.append(f"\n #### {provider}")
test_table = [
"| Area | Model | API | Functionality Test | Status |",
"|:-----|:-----|:-----|:-----|:-----|",
]
for api in INFERENCE_APIS:
tests = self.inference_tests[provider][api]
for test_nodeid in tests:
row = "|{area} | {model} | {api} | {test} | {result} ".format(
area="Text" if "text" in test_nodeid else "Vision",
model=("Llama-3.1-8B-Instruct" if "text" in test_nodeid else "Llama3.2-11B-Vision-Instruct"),
api=f"/{api}",
test=self.get_simple_function_name(test_nodeid),
result=("" if self.test_data[test_nodeid]["outcome"] == "passed" else ""),
)
test_table += [row]
report.extend(test_table)
report.append("\n")
output_file = Path(self.output_path)
output_file.write_text("\n".join(report))
print(f"\n Report generated: {output_file.absolute()}")
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(self, session, config, items):
for item in items:
inference = item.callspec.params.get("inference_stack")
if "inference" in item.nodeid:
func_name = getattr(item, "originalname", item.name)
for api in INFERENCE_APIS:
if api in func_name:
api_tests = self.inference_tests[inference].get(api, set())
api_tests.add(item.nodeid)
self.inference_tests[inference][api] = api_tests
def get_simple_function_name(self, nodeid):
"""Extract function name from nodeid.
Examples:
- 'tests/test_math.py::test_addition' -> 'test_addition'
- 'tests/test_math.py::TestClass::test_method' -> 'test_method'
"""
parts = nodeid.split("::")
func_name = nodeid # Fallback to full nodeid if pattern doesn't match
if len(parts) == 2: # Simple function
func_name = parts[1]
elif len(parts) == 3: # Class method
func_name = parts[2]
return func_name.split("[")[0]

View file

@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -59,9 +59,12 @@ logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
InferenceProvider,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.
# potential change: add a prefix param that gets added to the model name
# when calling litellm.
def __init__(
self,
model_entries,
@ -92,7 +95,9 @@ class LiteLLMOpenAIMixin(
return model
def get_litellm_model_name(self, model_id: str) -> str:
return "openai/" + model_id if self.is_openai_compat else model_id
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
async def completion(
self,

View file

@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
"arguments": arguments,
},
}
)

View file

@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:

View file

@ -16,4 +16,6 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> list[str]: ...
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
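# Usage sketch (illustrative; the prefix and sentinel values are hypothetical):
#   values = await store.values_in_range("sessions:", "sessions:\xff")
#   keys = await store.keys_in_range("sessions:", "sessions:\xff")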

View file

@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
return [key for key in self._store.keys() if key >= start_key and key < end_key]
async def delete(self, key: str) -> None:
del self._store[key]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:

View file

@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
async for doc in cursor:
result.append(doc["value"])
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {"key": {"$gte": start_key, "$lt": end_key}}
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
return [doc["key"] for doc in cursor]

View file

@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,),
)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]

View file

@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
]
return []
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
if not matching_keys:
return []
return [k.decode("utf-8") for k in matching_keys]

View file

@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
_, value, _ = row
result.append(value)
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]

View file

@ -118,45 +118,53 @@ async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
return interleaved_content_as_str(doc.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
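# In short (after this change): URL objects and URL-like strings are fetched over
# HTTP (PDFs parsed, other responses returned as text), data: URIs are decoded
# inline, plain strings are returned as-is, and anything else falls back to
# interleaved_content_as_str.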
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
metadata_string = str(metadata)
except Exception as e:
raise ValueError("Failed to serialize metadata to string") from e
metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunk_metadata = metadata.copy()
chunk_metadata["document_id"] = document_id
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata={
"token_count": len(toks),
"document_id": document_id,
},
metadata=chunk_metadata,
)
)
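# Illustrative call (values are hypothetical): with window_len=512 and overlap_len=64
# the loop advances 448 tokens per chunk, so adjacent chunks share 64 tokens, and each
# Chunk carries the caller's metadata plus document_id, token_count and
# metadata_token_count.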

Some files were not shown because too many files have changed in this diff.