Merge branch 'main' into add-mongodb-vector_io

This commit is contained in:
Young Han 2025-11-11 11:13:23 -08:00 committed by GitHub
commit 5e9d28f0b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1791 changed files with 125464 additions and 386541 deletions

View file

@ -3,8 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.library_client import ( # noqa: F401
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)

View file

@ -87,6 +87,7 @@ class Agents(Protocol):
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
@ -97,6 +98,7 @@ class Agents(Protocol):
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
:returns: An OpenAIResponseObject.
"""
...

View file

@ -403,7 +403,7 @@ class OpenAIResponseText(BaseModel):
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11", "web_search_2025_08_26"]
@json_schema_type
@ -415,9 +415,12 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
"""
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
type: (
Literal["web_search"]
| Literal["web_search_preview"]
| Literal["web_search_preview_2025_03_11"]
| Literal["web_search_2025_08_26"]
) = "web_search"
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
@ -591,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
"""
created_at: int
@ -612,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
max_tool_calls: int | None = None
@json_schema_type

View file

@ -74,7 +74,7 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def register_benchmark(
self,
benchmark_id: str,
@ -95,7 +95,7 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.

View file

@ -34,3 +34,44 @@ class PaginatedResponse(BaseModel):
data: list[dict[str, Any]]
has_more: bool
url: str | None = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
@json_schema_type
class MetricInResponse(BaseModel):
"""A metric value included in API responses.
:param metric: The name of the metric
:param value: The numeric value of the metric
:param unit: (Optional) The unit of measurement for the metric value
"""
metric: str
value: int | float
unit: str | None = None
class MetricResponseMixin(BaseModel):
"""Mixin class for API responses that can include metrics.
:param metrics: (Optional) List of metrics associated with the API response
"""
metrics: list[MetricInResponse] | None = None

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def telemetry_traceable(cls):
"""
Mark a protocol for automatic tracing when telemetry is enabled.
This is a metadata-only decorator with no dependencies on core.
Actual tracing is applied by core routers at runtime if telemetry is enabled.
Usage:
@runtime_checkable
@telemetry_traceable
class MyProtocol(Protocol):
...
"""
cls.__marked_for_tracing__ = True
return cls

View file

@ -6,26 +6,22 @@
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -20,8 +20,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@ -102,32 +102,6 @@ register_schema(ConversationItem, name="ConversationItem")
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
@ -183,7 +157,7 @@ class ConversationItemDeletedResource(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Conversations(Protocol):
"""Conversations

View file

@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def register_dataset(
self,
purpose: DatasetPurpose,
@ -235,7 +235,7 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -11,8 +11,8 @@ from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -102,7 +102,7 @@ class OpenAIFileDeleteResponse(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Files(Protocol):
"""Files

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from enum import Enum
from enum import Enum, StrEnum
from typing import (
Annotated,
Any,
@ -15,29 +15,18 @@ from typing import (
)
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.responses import (
Order,
)
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.models import Model
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.telemetry.telemetry import MetricResponseMixin
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolDefinition)
from enum import StrEnum
@json_schema_type
class GreedySamplingStrategy(BaseModel):
@ -202,58 +191,6 @@ class ToolResponseMessage(BaseModel):
content: InterleavedContent
@json_schema_type
class CompletionMessage(BaseModel):
"""A message containing the model's (assistant) response in a chat conversation.
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param stop_reason: Reason why the model stopped generating. Options are:
- `StopReason.end_of_turn`: The model finished generating the entire response.
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
- `StopReason.out_of_tokens`: The model ran out of token budget.
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
class ToolResponse(BaseModel):
"""Response from a tool invocation.
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was invoked
:param content: The response content from the tool
:param metadata: (Optional) Additional metadata about the tool response
"""
call_id: str
tool_name: BuiltinTool | str
content: InterleavedContent
metadata: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@ -290,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress"
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""An event during chat completion generation.
:param event_type: Type of the event
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
:param logprobs: Optional log probabilities for generated tokens
:param stop_reason: Optional reason why generation stopped, if complete
"""
event_type: ChatCompletionResponseEventType
delta: ContentDelta
logprobs: list[TokenLogProbs] | None = None
stop_reason: StopReason | None = None
class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding.
@ -358,34 +279,6 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
:param stop_reason: Reason why generation stopped
:param logprobs: Optional log probabilities for generated tokens
"""
content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
:param stop_reason: Optional reason why generation stopped, if complete
:param logprobs: Optional log probabilities for generated tokens
"""
delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
@ -399,70 +292,6 @@ class SystemMessageBehavior(Enum):
replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.tool_choice, str):
try:
self.tool_choice = ToolChoice[self.tool_choice]
except KeyError:
pass
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
"""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
:param logprobs: Optional log probabilities for generated tokens
"""
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class EmbeddingsResponse(BaseModel):
"""Response containing generated embeddings.
@ -1160,7 +989,7 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.

View file

@ -76,7 +76,7 @@ class Inspect(Protocol):
List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
:returns: Response containing information about all available routes.
"""
...

View file

@ -9,9 +9,9 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -105,7 +105,7 @@ class OpenAIListModelsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Models(Protocol):
async def list_models(self) -> ListModelsResponse:
"""List all models.
@ -136,7 +136,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_model(
self,
model_id: str,
@ -158,7 +158,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_model(
self,
model_id: str,

View file

@ -10,8 +10,8 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -92,7 +92,7 @@ class ListPromptsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Prompts(Protocol):
"""Prompts

View file

@ -9,10 +9,10 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -94,7 +94,7 @@ class ShieldStore(Protocol):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Safety(Protocol):
"""Safety

View file

@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_scoring_function(
self,
scoring_fn_id: str,
@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(
route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.

View file

@ -8,9 +8,9 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -48,7 +48,7 @@ class ListShieldsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class Shields(Protocol):
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
async def list_shields(self) -> ListShieldsResponse:
@ -67,7 +67,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_shield(
self,
shield_id: str,
@ -85,7 +85,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.

View file

@ -5,18 +5,13 @@
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
@ -30,7 +25,6 @@ class RRFRanker(BaseModel):
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
@ -55,10 +49,8 @@ Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
@ -75,7 +67,6 @@ class RAGDocument(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
@ -87,7 +78,6 @@ class RAGQueryResult(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
@ -101,7 +91,6 @@ class RAGQueryGenerator(Enum):
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
@ -115,7 +104,6 @@ class RAGSearchMode(StrEnum):
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
@ -127,7 +115,6 @@ class DefaultRAGQueryGeneratorConfig(BaseModel):
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
@ -145,10 +132,8 @@ RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
@ -181,38 +166,3 @@ class RAGQueryConfig(BaseModel):
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_store_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_store_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...

View file

@ -11,13 +11,11 @@ from pydantic import BaseModel
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolDef(BaseModel):
@ -109,9 +107,9 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_tool_group(
self,
toolgroup_id: str,
@ -169,7 +167,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_toolgroup(
self,
toolgroup_id: str,
@ -191,12 +189,10 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(

View file

@ -13,10 +13,10 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from pydantic import BaseModel, Field
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@ -260,7 +260,7 @@ class VectorStoreSearchResponsePage(BaseModel):
"""
object: str = "vector_store.search_results.page"
search_query: str
search_query: list[str]
data: list[VectorStoreSearchResponse]
has_more: bool = False
next_page: str | None = None
@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel):
@json_schema_type
class VectorStoreFileContentsResponse(BaseModel):
"""Response from retrieving the contents of a vector store file.
class VectorStoreFileContentResponse(BaseModel):
"""Represents the parsed content of a vector store file.
:param file_id: Unique identifier for the file
:param filename: Name of the file
:param attributes: Key-value attributes associated with the file
:param content: List of content items from the file
:param object: The object type, which is always `vector_store.file_content.page`
:param data: Parsed content of the file
:param has_more: Indicates if there are more content pages to fetch
:param next_page: The token for the next page, if any
"""
file_id: str
filename: str
attributes: dict[str, Any]
content: list[VectorStoreContent]
object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
data: list[VectorStoreContent]
has_more: bool
next_page: str | None = None
@json_schema_type
@ -478,7 +478,7 @@ class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
name: str | None = None
file_ids: list[str] | None = None
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
chunking_strategy: VectorStoreChunkingStrategy | None = None
metadata: dict[str, Any] | None = None
@ -502,7 +502,7 @@ class VectorStoreTable(Protocol):
@runtime_checkable
@trace_protocol
@telemetry_traceable
class VectorIO(Protocol):
vector_store_table: VectorStoreTable | None = None
@ -732,12 +732,12 @@ class VectorIO(Protocol):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
) -> VectorStoreFileContentResponse:
"""Retrieves the contents of a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A list of InterleavedContent representing the file contents.
:returns: A VectorStoreFileContentResponse representing the file contents.
"""
...

View file

@ -46,6 +46,10 @@ class StackListDeps(Subcommand):
def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
if not args.config and not args.providers:
self.parser.print_help()
self.parser.exit()
from ._list_deps import run_stack_list_deps_command
return run_stack_list_deps_command(args)

View file

@ -9,48 +9,69 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
class StackListBuilds(Subcommand):
"""List built stacks in .llama/distributions directory"""
"""List available distributions (both built-in and custom)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama stack list",
description="list the build stacks",
description="list available distributions",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._list_stack_command)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
def _get_distribution_dirs(self) -> dict[str, tuple[Path, str]]:
"""Return a dictionary of distribution names and their paths with source type
Returns:
dict mapping distro name to (path, source_type) where source_type is 'built-in' or 'custom'
"""
distributions = {}
# Get built-in distributions from source code
distro_dir = Path(__file__).parent.parent.parent / "distributions"
if distro_dir.exists():
for stack_dir in distro_dir.iterdir():
if stack_dir.is_dir() and not stack_dir.name.startswith(".") and not stack_dir.name.startswith("__"):
distributions[stack_dir.name] = (stack_dir, "built-in")
# Get custom/run distributions from ~/.llama/distributions
# These override built-in ones if they have the same name
if DISTRIBS_BASE_DIR.exists():
for stack_dir in DISTRIBS_BASE_DIR.iterdir():
if stack_dir.is_dir() and not stack_dir.name.startswith("."):
# Clean up the name (remove llamastack- prefix if present)
name = stack_dir.name.replace("llamastack-", "")
distributions[name] = (stack_dir, "custom")
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stack_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
print("No distributions found")
return
headers = ["Stack Name", "Path"]
headers.extend(["Build Config", "Run Config"])
headers = ["Stack Name", "Source", "Path", "Build Config", "Run Config"]
rows = []
for name, path in distributions.items():
row = [name, str(path)]
for name, (path, source_type) in sorted(distributions.items()):
row = [name, source_type, str(path)]
# Check for build and run config files
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
# For built-in distributions, configs are named build.yaml and run.yaml
# For custom distributions, configs are named {name}-build.yaml and {name}-run.yaml
if source_type == "built-in":
build_config = "Yes" if (path / "build.yaml").exists() else "No"
run_config = "Yes" if (path / "run.yaml").exists() else "No"
else:
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
row.extend([build_config, run_config])
rows.append(row)
print_table(rows, headers, separate_rows=True)

View file

@ -253,7 +253,7 @@ class StackRun(Subcommand):
)
return
ui_dir = REPO_ROOT / "llama_stack" / "ui"
ui_dir = REPO_ROOT / "llama_stack_ui"
logs_dir = Path("~/.llama/ui/logs").expanduser()
try:
# Create logs directory if it doesn't exist

View file

@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
RouteInfo,
VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated v1 APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
# Default: only non-deprecated APIs
return not webmethod.deprecated
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated)

View file

@ -18,14 +18,21 @@ from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
try:
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
except ImportError as e:
raise ImportError(
"llama-stack-client is not installed. Please install it with `uv pip install llama-stack[client]`."
) from e
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint

View file

@ -397,6 +397,18 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
# Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker
if run_config.telemetry.enabled:
traced_classes = [
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
]
if traced_classes:
from llama_stack.core.telemetry.trace_protocol import trace_protocol
for cls in traced_classes:
trace_protocol(cls)
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups

View file

@ -45,6 +45,7 @@ async def get_routing_table_impl(
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize()
return impl
@ -92,5 +93,6 @@ async def get_auto_router_impl(
api_to_dep_impl["safety_config"] = run_config.safety
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -190,7 +190,7 @@ class InferenceRouter(Inference):
response = await provider.openai_completion(params)
response.model = request_model_id
if self.telemetry_enabled:
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
@ -253,7 +253,7 @@ class InferenceRouter(Inference):
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry_enabled:
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,

View file

@ -6,7 +6,7 @@
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
@ -52,7 +52,7 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")

View file

@ -8,14 +8,9 @@ from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
@ -26,36 +21,6 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
@ -63,11 +28,6 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass

View file

@ -20,9 +20,11 @@ from llama_stack.apis.vector_io import (
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
@ -167,6 +169,13 @@ class VectorIORouter(VectorIO):
if embedding_dimension is not None:
params.model_extra["embedding_dimension"] = embedding_dimension
# Set chunking strategy explicitly if not provided
if params.chunking_strategy is None or params.chunking_strategy.type == "auto":
# actualize the chunking strategy to static
params.chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig()
)
return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores(
@ -283,6 +292,8 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
if chunking_strategy is None or chunking_strategy.type == "auto":
chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig())
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
@ -327,7 +338,7 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
) -> VectorStoreFileContentResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(

View file

@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
) -> VectorStoreFileContentResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(

View file

@ -13,7 +13,6 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])

View file

@ -31,7 +31,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
@ -78,7 +78,6 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,

View file

@ -163,47 +163,6 @@ class MetricEvent(EventCommon):
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
"""A metric value included in API responses.
:param metric: The name of the metric
:param value: The numeric value of the metric
:param unit: (Optional) The unit of measurement for the metric value
"""
metric: str
value: int | float
unit: str | None = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
"""Mixin class for API responses that can include metrics.
:param metrics: (Optional) List of metrics associated with the API response
"""
metrics: list[MetricInResponse] | None = None
@json_schema_type
class StructuredLogType(Enum):
"""The type of structured log event payload.
@ -427,6 +386,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"counters": {},
"gauges": {},
"up_down_counters": {},
"histograms": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
@ -540,6 +500,16 @@ class Telemetry:
)
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["histograms"]:
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
name=name,
unit=unit,
description=f"Histogram for {name}",
)
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span
try:
@ -571,7 +541,16 @@ class Telemetry:
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):
# Use histograms for token-related metrics (per-request measurements)
# Use counters for other cumulative metrics
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
if event.metric in token_metrics:
# Token metrics are per-request measurements, use histogram
histogram = self._get_or_create_histogram(event.metric, event.unit)
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, float):

View file

@ -129,6 +129,15 @@ def trace_protocol[T: type[Any]](cls: T) -> T:
else:
return sync_wrapper
# Wrap methods on the class itself (for classes applied at runtime)
# Skip if already wrapped (indicated by __wrapped__ attribute)
for name, method in vars(cls).items():
if inspect.isfunction(method) and not name.startswith("_"):
if not hasattr(method, "__wrapped__"):
wrapped = trace_method(method)
setattr(cls, name, wrapped) # noqa: B010
# Also set up __init_subclass__ for future subclasses
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807

View file

@ -1,11 +0,0 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

View file

@ -1,50 +0,0 @@
# (Experimental) LLama Stack UI
## Docker Setup
:warning: This is a work in progress.
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
```
llama stack list-deps together | xargs -L1 uv pip install
llama stack run together
```
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```
```bash
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
3. Start Streamlit UI
```bash
uv run --with ".[ui]" streamlit run llama_stack.core/ui/app.py
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,55 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
def main():
# Evaluation pages
application_evaluation_page = st.Page(
"page/evaluations/app_eval.py",
title="Evaluations (Scoring)",
icon="📊",
default=False,
)
native_evaluation_page = st.Page(
"page/evaluations/native_eval.py",
title="Evaluations (Generation + Scoring)",
icon="📊",
default=False,
)
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",
icon="🔍",
default=False,
)
pg = st.navigation(
{
"Playground": [
chat_page,
rag_page,
tool_page,
application_evaluation_page,
native_evaluation_page,
],
"Inspect": [provider_page, resources_page],
},
expanded=False,
)
pg.run()
if __name__ == "__main__":
main()

View file

@ -1,32 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from llama_stack_client import LlamaStackClient
class LlamaStackApi:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
},
)
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = dict.fromkeys(scoring_function_ids)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()

View file

@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None
def data_url_from_file(file) -> str:
file_content = file.getvalue()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type = file.type
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def benchmarks():
# Benchmarks Section
st.header("Benchmarks")
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
if len(benchmarks_info) > 0:
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
st.json(benchmarks_info[selected_benchmark], expanded=True)

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def providers():
st.header("🔍 API Providers")
apis_providers_lst = llama_stack_api.client.providers.list()
api_to_providers = {}
for api_provider in apis_providers_lst:
if api_provider.api in api_to_providers:
api_to_providers[api_provider.api].append(api_provider)
else:
api_to_providers[api_provider.api] = [api_provider]
for api in api_to_providers.keys():
st.markdown(f"###### {api}")
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
providers()

View file

@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from streamlit_option_menu import option_menu
from llama_stack.core.ui.page.distribution.datasets import datasets
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
def resources_page():
options = [
"Models",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
icons=icons,
orientation="horizontal",
styles={
"nav-link": {
"font-size": "12px",
},
},
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":
models()
elif selected_resource == "Scoring Functions":
scoring_functions()
elif selected_resource == "Shields":
shields()
resources_page()

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
st.json(scoring_functions_info[selected_scoring_function], expanded=True)

View file

@ -1,19 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def shields():
# Shields Section
st.header("Shields")
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,143 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = llama_stack_api.client.scoring_functions.list()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = llama_stack_api.client.models.list()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = llama_stack_api.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
application_evaluation_page()

View file

@ -1,253 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def select_benchmark_1():
# Select Benchmarks
st.subheader("1. Choose An Eval Task")
benchmarks = llama_stack_api.client.benchmarks.list()
benchmarks = {et.identifier: et for et in benchmarks}
benchmarks_names = list(benchmarks.keys())
selected_benchmark = st.selectbox(
"Choose an eval task.",
options=benchmarks_names,
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
)
with st.expander("View Eval Task"):
st.json(benchmarks[selected_benchmark], expanded=True)
st.session_state["selected_benchmark"] = selected_benchmark
st.session_state["benchmarks"] = benchmarks
if st.button("Confirm", key="confirm_1"):
st.session_state["selected_benchmark_1_next"] = True
def define_eval_candidate_2():
if not st.session_state.get("selected_benchmark_1_next", None):
return
st.subheader("2. Define Eval Candidate")
st.info(
"""
Define the configurations for the evaluation candidate model or agent used for generation.
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
"""
)
with st.expander("Define Eval Candidate", expanded=True):
# Define Eval Candidate
candidate_type = st.radio("Candidate Type", ["model", "agent"])
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
# Sampling Parameters
st.markdown("##### Sampling Parameters")
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
if candidate_type == "model":
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
eval_candidate = {
"type": "model",
"model": selected_model,
"sampling_params": {
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
}
elif candidate_type == "agent":
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
tools_json = st.text_area(
"Tools Configuration (JSON)",
value=json.dumps(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": "ENTER_BRAVE_API_KEY_HERE",
}
]
),
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
height=200,
)
try:
tools = json.loads(tools_json)
except json.JSONDecodeError:
st.error("Invalid JSON format for tools configuration")
tools = []
eval_candidate = {
"type": "agent",
"config": {
"model": selected_model,
"instructions": system_prompt,
"tools": tools,
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False,
},
}
st.session_state["eval_candidate"] = eval_candidate
if st.button("Confirm", key="confirm_2"):
st.session_state["selected_eval_candidate_2_next"] = True
def run_evaluation_3():
if not st.session_state.get("selected_eval_candidate_2_next", None):
return
st.subheader("3. Run Evaluation")
# Add info box to explain configurations being used
st.info(
"""
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
"""
)
selected_benchmark = st.session_state["selected_benchmark"]
benchmarks = st.session_state["benchmarks"]
eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id,
)
total_rows = len(rows.data)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
min_value=1,
max_value=total_rows,
value=5,
help="Number of examples from the dataset to evaluate. ",
)
benchmark_config = {
"type": "benchmark",
"eval_candidate": eval_candidate,
"scoring_params": {},
}
with st.expander("View Evaluation Task", expanded=True):
st.json(benchmarks[selected_benchmark], expanded=True)
with st.expander("View Evaluation Task Configuration", expanded=True):
st.json(benchmark_config, expanded=True)
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.data
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
eval_res = llama_stack_api.client.eval.evaluate_rows(
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
benchmark_config=benchmark_config,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for k in eval_res.generations[0].keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(eval_res.generations[0][k])
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
if scoring_fn not in output_res:
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")
select_benchmark_1()
define_eval_candidate_2()
run_evaluation_3()
native_evaluation_page()

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,134 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [
model.id
for model in available_models
if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
stream = st.checkbox("Stream", value=True)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.rerun()
# Main chat interface
st.title("🦙 Chat")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Example: What is Llama Stack?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Display assistant response
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
response = llama_stack_api.client.inference.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
model_id=selected_model,
stream=stream,
sampling_params={
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
)
if stream:
for chunk in response:
if chunk.event.event_type == "progress":
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -1,352 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
import json
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
from llama_stack.core.ui.modules.api import llama_stack_api
class AgentType(enum.Enum):
REGULAR = "Regular"
REACT = "ReAct"
def tool_chat_page():
st.title("🛠 Tools")
client = llama_stack_api.client
models = client.models.list()
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
tool_groups = client.toolgroups.list()
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_stores = []
def reset_agent():
st.session_state.clear()
st.cache_resource.clear()
with st.sidebar:
st.title("Configuration")
st.subheader("Model")
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
st.subheader("Available ToolGroups")
toolgroup_selection = st.pills(
label="Built-in tools",
options=builtin_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of built-in tools from your llama stack server.",
)
if "builtin::rag" in toolgroup_selection:
vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_stores:
st.info("No vector databases available for selection.")
vector_stores = [vector_store.identifier for vector_store in vector_stores]
selected_vector_stores = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_stores,
on_change=reset_agent,
)
mcp_selection = st.pills(
label="MCP Servers",
options=mcp_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of MCP servers registered to your llama stack server.",
)
toolgroup_selection.extend(mcp_selection)
grouped_tools = {}
total_tools = 0
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
total_tools += len(tools)
st.markdown(f"Active Tools: 🛠 {total_tools}")
for group_id, tools in grouped_tools.items():
with st.expander(f"🔧 Tools from `{group_id}`"):
for idx, tool in enumerate(tools, start=1):
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
label="Select Agent Type",
options=["Regular", "ReAct"],
on_change=reset_agent,
)
if agent_type == "ReAct":
agent_type = AgentType.REACT
else:
agent_type = AgentType.REGULAR
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=64,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
for i, tool_name in enumerate(toolgroup_selection):
if tool_name == "builtin::rag":
tool_dict = dict(
name="builtin::rag",
args={
"vector_store_ids": list(selected_vector_stores),
},
)
toolgroup_selection[i] = tool_dict
@st.cache_resource
def create_agent():
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if prompt := st.chat_input(placeholder=""):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
def response_generator(turn_response):
if st.session_state.get("agent_type") == AgentType.REACT:
return _handle_react_response(turn_response)
else:
return _handle_regular_response(turn_response)
def _handle_react_response(turn_response):
current_step_content = ""
final_answer = None
tool_results = []
for response in turn_response:
if not hasattr(response.event, "payload"):
yield (
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
"The response received is missing an expected `payload` attribute.\n"
"This could indicate a malformed response or an internal issue within the server.\n\n"
f"Error details: {response}"
)
return
payload = response.event.payload
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
current_step_content += payload.delta.text
continue
if payload.event_type == "step_complete":
step_details = payload.step_details
if step_details.step_type == "inference":
yield from _process_inference_step(current_step_content, tool_results, final_answer)
current_step_content = ""
elif step_details.step_type == "tool_execution":
tool_results = _process_tool_execution(step_details, tool_results)
current_step_content = ""
else:
current_step_content = ""
if not final_answer and tool_results:
yield from _format_tool_results_summary(tool_results)
def _process_inference_step(current_step_content, tool_results, final_answer):
try:
react_output_data = json.loads(current_step_content)
thought = react_output_data.get("thought")
action = react_output_data.get("action")
answer = react_output_data.get("answer")
if answer and answer != "null" and answer is not None:
final_answer = answer
if thought:
with st.expander("🤔 Thinking...", expanded=False):
st.markdown(f":grey[__{thought}__]")
if action and isinstance(action, dict):
tool_name = action.get("tool_name")
tool_params = action.get("tool_params")
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
st.json(tool_params)
if answer and answer != "null" and answer is not None:
yield f"\n\n✅ **Final Answer:**\n{answer}"
except json.JSONDecodeError:
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
except Exception as e:
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
return final_answer
def _process_tool_execution(step_details, tool_results):
try:
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
for tool_response in step_details.tool_responses:
tool_name = tool_response.tool_name
content = tool_response.content
tool_results.append((tool_name, content))
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
try:
parsed_content = json.loads(content)
st.json(parsed_content)
except json.JSONDecodeError:
st.code(content, language=None)
else:
with st.expander("⚙️ Observation", expanded=False):
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
except Exception as e:
with st.expander("⚙️ Error in Tool Execution", expanded=False):
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
return tool_results
def _format_tool_results_summary(tool_results):
yield "\n\n**Here's what I found:**\n"
for tool_name, content in tool_results:
try:
parsed_content = json.loads(content)
if tool_name == "web_search" and "top_k" in parsed_content:
yield from _format_web_search_results(parsed_content)
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
yield from _format_results_list(parsed_content["results"])
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
yield from _format_dict_results(parsed_content)
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
yield from _format_list_results(parsed_content)
except json.JSONDecodeError:
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
except (TypeError, AttributeError, KeyError, IndexError) as e:
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
def _format_web_search_results(parsed_content):
for i, result in enumerate(parsed_content["top_k"], 1):
if i <= 3:
title = result.get("title", "Untitled")
url = result.get("url", "")
content_text = result.get("content", "").strip()
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
def _format_results_list(results):
for i, result in enumerate(results, 1):
if i <= 3:
if isinstance(result, dict):
name = result.get("name", result.get("title", "Result " + str(i)))
description = result.get("description", result.get("content", result.get("summary", "")))
yield f"\n- **{name}**\n {description}\n"
else:
yield f"\n- {result}\n"
def _format_dict_results(parsed_content):
yield "\n```\n"
for key, value in list(parsed_content.items())[:5]:
if isinstance(value, str) and len(value) < 100:
yield f"{key}: {value}\n"
else:
yield f"{key}: [Complex data]\n"
yield "```\n"
def _format_list_results(parsed_content):
yield "\n"
for _, item in enumerate(parsed_content[:3], 1):
if isinstance(item, str):
yield f"- {item}\n"
elif isinstance(item, dict) and "text" in item:
yield f"- {item['text']}\n"
elif isinstance(item, dict) and len(item) > 0:
first_value = next(iter(item.values()))
if isinstance(first_value, str) and len(first_value) < 100:
yield f"- {first_value}\n"
def _handle_regular_response(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
if response.event.payload.event_type == "step_progress":
if hasattr(response.event.payload.delta, "text"):
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
if response.event.payload.step_details.tool_calls:
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
else:
yield "No tool_calls present in step_details"
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response_content = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response_content})
tool_chat_page()

View file

@ -1,5 +0,0 @@
llama-stack>=0.2.1
llama-stack-client>=0.2.1
pandas
streamlit
streamlit-option-menu

View file

@ -52,7 +52,17 @@ def resolve_config_or_distro(
logger.debug(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
# Strategy 3: Try as distro config path (if no .yaml extension and contains a slash)
# eg: starter::run-with-postgres-store.yaml
# Use :: to avoid slash and confusion with a filesystem path
if "::" in config_or_distro:
distro_name, config_name = config_or_distro.split("::")
distro_config = _get_distro_config_path(distro_name, config_name)
if distro_config.exists():
logger.info(f"Using distribution: {distro_config}")
return distro_config
# Strategy 4: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.debug(f"Using built distribution: {distrib_config}")
@ -63,13 +73,15 @@ def resolve_config_or_distro(
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
# Strategy 5: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_distro, mode))
def _get_distro_config_path(distro_name: str, mode: Mode) -> Path:
def _get_distro_config_path(distro_name: str, mode: str) -> Path:
"""Get the config file path for a distro."""
return DISTRO_DIR / distro_name / f"{mode}.yaml"
if not mode.endswith(".yaml"):
mode = f"{mode}.yaml"
return DISTRO_DIR / distro_name / mode
def _format_resolution_error(config_or_distro: str, mode: Mode) -> str:

View file

@ -84,6 +84,15 @@ def run_command(command: list[str]) -> int:
text=True,
check=False,
)
# Print stdout and stderr if command failed
if result.returncode != 0:
log.error(f"Command {' '.join(command)} failed with returncode {result.returncode}")
if result.stdout:
log.error(f"STDOUT: {result.stdout}")
if result.stderr:
log.error(f"STDERR: {result.stderr}")
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")

View file

@ -57,4 +57,5 @@ image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- psycopg2-binary
- sqlalchemy[asyncio]

View file

@ -13,5 +13,6 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template(name="ci-tests")
template.description = "CI tests for Llama Stack"
template.run_configs.pop("run-with-postgres-store.yaml", None)
return template

View file

@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:

View file

@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .oci import get_distribution_template # noqa: F401

View file

@ -0,0 +1,35 @@
version: 2
distribution_spec:
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
inference with scalable cloud services
providers:
inference:
- provider_type: remote::oci
vector_io:
- provider_type: inline::faiss
- provider_type: remote::chromadb
- provider_type: remote::pgvector
safety:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,140 @@
---
orphan: true
---
# OCI Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or local virtual environment.
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.oci.config import OCIConfig
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::oci")],
"vector_io": [
BuildProvider(provider_type="inline::faiss"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
BuildProvider(provider_type="inline::localfs"),
],
"scoring": [
BuildProvider(provider_type="inline::basic"),
BuildProvider(provider_type="inline::llm-as-judge"),
BuildProvider(provider_type="inline::braintrust"),
],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
provider_id="oci",
provider_type="remote::oci",
config=OCIConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
]
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider],
"files": [files_provider],
},
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"OCI_AUTH_TYPE": (
"instance_principal",
"OCI authentication type (instance_principal or config_file)",
),
"OCI_REGION": (
"",
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
),
"OCI_COMPARTMENT_OCID": (
"",
"OCI compartment ID for the Generative AI service",
),
"OCI_CONFIG_FILE_PATH": (
"~/.oci/config",
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
),
"OCI_CLI_PROFILE": (
"DEFAULT",
"OCI CLI profile name to use from config file",
),
},
)

View file

@ -0,0 +1,136 @@
version: 2
image_name: oci
apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: oci
provider_type: remote::oci
config:
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
metadata_store:
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
server:
port: 8321
telemetry:
enabled: true

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .postgres_demo import get_distribution_template # noqa: F401

View file

@ -1,23 +0,0 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
providers:
inference:
- provider_type: remote::vllm
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: remote::chromadb
safety:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:
- asyncpg
- psycopg2-binary
- sqlalchemy[asyncio]

View file

@ -1,125 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BuildProvider,
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.distributions.template import (
DistributionTemplate,
RunConfigSettings,
)
from llama_stack.providers.inline.inference.sentence_transformers import SentenceTransformersInferenceConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
def get_distribution_template() -> DistributionTemplate:
inference_providers = [
Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL:=http://localhost:8000/v1}",
),
),
]
providers = {
"inference": [
BuildProvider(provider_type="remote::vllm"),
BuildProvider(provider_type="inline::sentence-transformers"),
],
"vector_io": [BuildProvider(provider_type="remote::chromadb")],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
name = "postgres-demo"
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_CHROMADB:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
url="${env.CHROMADB_URL:=}",
),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = [
ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
]
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
embedding_model = ModelInput(
model_id="nomic-embed-text-v1.5",
provider_id=embedding_provider.provider_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
},
)
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider={},
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers + [embedding_provider],
"vector_io": vector_io_providers,
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
storage_backends={
"kv_default": PostgresKVStoreConfig.sample_run_config(
table_name="llamastack_kvstore",
),
"sql_default": PostgresSqlStoreConfig.sample_run_config(),
},
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
},
)

View file

@ -58,4 +58,5 @@ image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- psycopg2-binary
- sqlalchemy[asyncio]

View file

@ -0,0 +1,284 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
persistence:
namespace: vector_io::sqlite_vec
backend: kv_default
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
persistence:
namespace: vector_io::milvus
backend: kv_default
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
metadata_store:
table_name: files_metadata
backend: sql_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
post_training:
- provider_id: huggingface-gpu
provider_type: inline::huggingface-gpu
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
namespace: batches
backend: kv_postgres
storage:
backends:
kv_postgres:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_postgres:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
stores:
metadata:
namespace: registry
backend: kv_postgres
inference:
table_name: inference_store
backend: sql_postgres
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_postgres
prompts:
namespace: prompts
backend: kv_postgres
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
server:
port: 8321
telemetry:
enabled: true

View file

@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:

View file

@ -58,4 +58,5 @@ image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- psycopg2-binary
- sqlalchemy[asyncio]

View file

@ -0,0 +1,281 @@
version: 2
image_name: starter
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
persistence:
namespace: vector_io::sqlite_vec
backend: kv_default
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
persistence:
namespace: vector_io::milvus
backend: kv_default
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
table_name: files_metadata
backend: sql_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
post_training:
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
namespace: batches
backend: kv_postgres
storage:
backends:
kv_postgres:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_postgres:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
stores:
metadata:
namespace: registry
backend: kv_postgres
inference:
table_name: inference_store
backend: sql_postgres
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_postgres
prompts:
namespace: prompts
backend: kv_postgres
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
server:
port: 8321
telemetry:
enabled: true

View file

@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:

View file

@ -17,6 +17,11 @@ from llama_stack.core.datatypes import (
ToolGroupInput,
VectorStoresConfig,
)
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
SqlStoreReference,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
@ -39,6 +44,8 @@ from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOC
from llama_stack.providers.remote.vector_io.weaviate.config import (
WeaviateVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
@ -185,6 +192,62 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
]
postgres_config = PostgresSqlStoreConfig.sample_run_config()
default_overrides = {
"inference": remote_inference_providers + [embedding_provider],
"vector_io": [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.MILVUS_URL:+milvus}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.CHROMADB_URL:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
),
Provider(
provider_id="${env.PGVECTOR_DB:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
),
Provider(
provider_id="${env.QDRANT_URL:+qdrant}",
provider_type="remote::qdrant",
config=QdrantVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
url="${env.QDRANT_URL:=}",
),
),
Provider(
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
),
),
],
"files": [files_provider],
}
return DistributionTemplate(
name=name,
@ -193,7 +256,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
container_image=None,
template_path=None,
providers=providers,
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
@ -260,6 +323,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
],
"files": [files_provider],
},
provider_overrides=default_overrides,
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
@ -274,6 +338,55 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
default_shield_id="llama-guard",
),
),
"run-with-postgres-store.yaml": RunConfigSettings(
provider_overrides={
**default_overrides,
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
persistence_store=postgres_config,
responses_store=postgres_config,
),
)
],
"batches": [
Provider(
provider_id="reference",
provider_type="inline::reference",
config=dict(
kvstore=KVStoreReference(
backend="kv_postgres",
namespace="batches",
).model_dump(exclude_none=True),
),
)
],
},
storage_backends={
"kv_postgres": PostgresKVStoreConfig.sample_run_config(),
"sql_postgres": postgres_config,
},
storage_stores={
"metadata": KVStoreReference(
backend="kv_postgres",
namespace="registry",
).model_dump(exclude_none=True),
"inference": InferenceStoreReference(
backend="sql_postgres",
table_name="inference_store",
).model_dump(exclude_none=True),
"conversations": SqlStoreReference(
backend="sql_postgres",
table_name="openai_conversations",
).model_dump(exclude_none=True),
"prompts": KVStoreReference(
backend="kv_postgres",
namespace="prompts",
).model_dump(exclude_none=True),
},
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (

View file

@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
)
from termcolor import cprint
from llama_stack.models.llama.datatypes import ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer

View file

@ -15,13 +15,10 @@ from pathlib import Path
from termcolor import colored
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any
from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
)

View file

@ -8,8 +8,9 @@ import json
import re
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import RecursiveType
logger = get_logger(name=__name__, category="models::llama")

View file

@ -13,7 +13,7 @@
import textwrap
from llama_stack.apis.inference import ToolDefinition
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate,
PromptTemplateGeneratorBase,

View file

@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
include,
max_infer_iters,
guardrails,
max_tool_calls,
)
return result # type: ignore[no-any-return]

View file

@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
max_tool_calls: int | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
if max_tool_calls is not None and max_tool_calls < 1:
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
max_tool_calls=max_tool_calls,
)
if stream:
@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
max_tool_calls=max_tool_calls,
)
# Stream the response

View file

@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
max_tool_calls: int | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions
# Track total calls made to built-in tools
self.accumulated_builtin_tool_calls = 0
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
usage=self.accumulated_usage,
instructions=self.instructions,
prompt=self.prompt,
max_tool_calls=self.max_tool_calls,
)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
break
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Track number of calls made to built-in and mcp tools
self.accumulated_builtin_tool_calls += 1
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import math
from collections.abc import Generator
from typing import Optional
import torch
@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIResponseFormatJSONSchema,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class LlamaGenerator:
def __init__(
self,
@ -157,55 +146,56 @@ class LlamaGenerator:
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
request: OpenAIChatCompletionRequestWithExtraBody,
raw_messages: list,
):
"""Generate chat completion using OpenAI request format.
Args:
request: OpenAI chat completion request
raw_messages: Pre-converted list of RawMessage objects
"""
# Determine tool prompt format
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
# Prepare sampling params
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature if request.temperature is not None else 1.0,
top_p=request.top_p if request.top_p is not None else 1.0,
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
# Extract the actual schema from OpenAIJSONSchema TypedDict
schema_dict = request.response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema,
json_schema=schema_dict,
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
# Generate
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
logprobs=False,
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
logits_processor=logits_processor,
)

View file

@ -5,12 +5,19 @@
# the root directory of this source tree.
import asyncio
import time
import uuid
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
ToolChoice,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
"""Convert OpenAI tool format to ToolDefinition format."""
# OpenAI tools have function.name and function.parameters
return ToolDefinition(
tool_name=tool.function.name,
description=tool.function.description or "",
parameters=tool.function.parameters or {},
)
def _get_tool_choice_prompt(tool_choice, tools) -> str:
"""Generate prompt text for tool_choice behavior."""
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
return ""
elif tool_choice == ToolChoice.required or tool_choice == "required":
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none or tool_choice == "none":
return ""
else:
# Specific tool specified
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def _raw_content_as_str(content) -> str:
"""Convert RawContent to string for system messages."""
if isinstance(content, str):
return content
elif isinstance(content, RawTextItem):
return content.text
elif isinstance(content, list):
return "\n".join(_raw_content_as_str(c) for c in content)
else:
return "<media>"
def _augment_raw_messages_for_tools_llama_3_1(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions first (if present)
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# For OpenAI format, all tools are custom (have string names)
tool_gen = JsonCustomToolGenerator()
tool_template = tool_gen.gen(tool_definitions)
sys_content += tool_template.render()
sys_content += "\n"
# Add default system prompt
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content += default_template.render()
# Add existing system message if present
if existing_system_message:
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
# Create new system message
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
def _augment_raw_messages_for_tools_llama_4(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions if present
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# Use python_list format for Llama 4
tool_gen = PythonListCustomToolGeneratorLlama4()
system_prompt = None
if existing_system_message:
system_prompt = _raw_content_as_str(existing_system_message.content)
tool_template = tool_gen.gen(tool_definitions, system_prompt)
sys_content = tool_template.render()
elif existing_system_message:
# No tools, just use existing system message
sys_content = _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
if sys_content:
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
return messages
def augment_raw_messages_for_tools(
raw_messages: list[RawMessage],
params: OpenAIChatCompletionRequestWithExtraBody,
llama_model,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions based on model family."""
if not params.tools:
return raw_messages
# Determine augmentation strategy based on model family
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# Llama 3.1 and Llama 3.2 multimodal use JSON format
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# Llama 3.2/3.3/4 use python_list format
return _augment_raw_messages_for_tools_llama_4(
raw_messages,
params.tools,
params.tool_choice,
)
else:
# Default to Llama 3.1 style
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@ -136,17 +315,20 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
params=OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
max_tokens=20,
)
)
log.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
"No available model yet, please register your requested model or add your model in the resources first"
)
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
decode_assistant_message,
)
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Augment messages with tool definitions if tools are present
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
else:
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Check if streaming is requested
if params.stream:
return self._stream_chat_completion(generator, params)
# Non-streaming: collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Decode assistant message to extract tool calls and determine stop_reason
# Default to end_of_turn if generation completed normally
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# Convert tool calls to OpenAI format
openai_tool_calls = None
if decoded_message.tool_calls:
from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
)
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Determine finish_reason based on whether tool calls are present
finish_reason = "tool_calls" if openai_tool_calls else "stop"
# Extract content from decoded message
content = ""
if isinstance(decoded_message.content, str):
content = decoded_message.content
elif isinstance(decoded_message.content, list):
for item in decoded_message.content:
if isinstance(item, RawTextItem):
content += item.text
# Create OpenAI response
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
return OpenAIChatCompletion(
id=response_id,
object="chat.completion",
created=created,
model=params.model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=content,
tool_calls=openai_tool_calls,
),
finish_reason=finish_reason,
logprobs=None,
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=0, # TODO: calculate properly
completion_tokens=0, # TODO: calculate properly
total_tokens=0, # TODO: calculate properly
),
)
async def _stream_chat_completion(
self,
generator,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream chat completion chunks as they're generated."""
from llama_stack.apis.inference import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoiceDelta,
OpenAIChunkChoice,
)
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
generated_text = ""
# Yield chunks as tokens are generated
for result_batch in generator:
for result in result_batch:
if result.ignore_token or result.source != "output":
continue
generated_text += result.text
# Yield delta chunk with the new text
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
content=result.text,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
# After generation completes, decode the full message to extract tool calls
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# If tool calls are present, yield a final chunk with tool_calls
if decoded_message.tool_calls:
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Yield chunk with tool_calls
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
tool_calls=openai_tool_calls,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
finish_reason = "tool_calls"
else:
finish_reason = "stop"
# Yield final chunk with finish_reason
final_chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(),
finish_reason=finish_reason,
logprobs=None,
)
],
)
yield final_chunk

View file

@ -4,17 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from collections.abc import Callable
from functools import partial
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .parallel_utils import ModelParallelProcessGroup
@ -23,12 +18,14 @@ class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
task_type = task[0]
if task_type == "chat_completion":
# task[1] is [params, raw_messages]
params, raw_messages = task[1]
return self.llama.chat_completion(params, raw_messages)
else:
raise ValueError(f"Unexpected task type {task[0]}")
raise ValueError(f"Unexpected task type {task_type}")
def init_model_cb(
@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference")
@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
task: tuple[str, list]
class TaskResponse(BaseModel):
@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
req: tuple[str, list],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,

View file

@ -91,7 +91,7 @@ class TorchtuneCheckpointer:
if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
# Note: for saving hugging face format checkpoints, we only support saving adapter weights now
self._save_hf_format_checkpoint(model_file_path, state_dict)
else:
raise ValueError(f"Unsupported checkpoint format: {format}")

View file

@ -25,7 +25,7 @@ def llama_stack_instruct_to_torchtune_instruct(
)
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
assert len(input_messages) == 1, "llama stack instruct dataset format only supports 1 user message"
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"

View file

@ -27,7 +27,6 @@ from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
@ -91,7 +90,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,

View file

@ -223,7 +223,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None
if self.kvstore is None:
raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
@ -239,7 +240,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
return [i.vector_store for i in self.cache.values()]
async def unregister_vector_store(self, vector_store_id: str) -> None:
assert self.kvstore is not None
if self.kvstore is None:
raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")
if vector_store_id not in self.cache:
return
@ -248,6 +250,27 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
del self.cache[vector_store_id]
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.kvstore is None:
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
vector_store_data = await self.kvstore.get(key)
if not vector_store_data:
raise VectorStoreNotFoundError(vector_store_id)
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_store_id] = index
return index
async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = self.cache.get(vector_store_id)
if index is None:

View file

@ -412,6 +412,14 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
return [v.vector_store for v in self.cache.values()]
async def register_vector_store(self, vector_store: VectorStore) -> None:
if self.kvstore is None:
raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
# Save to kvstore for persistence
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
# Create and cache the index
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
@ -421,13 +429,16 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_store = self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
# Try to load from kvstore
if self.kvstore is None:
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
vector_store_data = await self.kvstore.get(key)
if not vector_store_data:
raise VectorStoreNotFoundError(vector_store_id)
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store=vector_store,
index=SQLiteVecIndex(

View file

@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
pip_packages=[],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
),
RemoteProviderSpec(
api=Api.inference,
@ -296,6 +297,20 @@ Available Models:
Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
RemoteProviderSpec(
api=Api.inference,
provider_type="remote::oci",
adapter_type="oci",
pip_packages=["oci"],
module="llama_stack.providers.remote.inference.oci",
config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig",
provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator",
description="""
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
""",
),
]

View file

@ -20,6 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service.
Build the NVIDIA environment:
```bash
uv pip install llama-stack-client
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

View file

@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)
impl = BedrockInferenceAdapter(config=config)
await impl.initialize()

Some files were not shown because too many files have changed in this diff Show more