This commit is contained in:
Xi Yan 2025-03-23 15:48:14 -07:00
commit a54d757ade
197 changed files with 9392 additions and 3089 deletions

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
]
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
@ -312,20 +309,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
name="AgentTurnResponseEventPayload",
)
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems.
@ -399,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
"""
@webmethod(route="/agents", method="POST")
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent(
self,
agent_config: AgentConfig,
@ -411,7 +405,9 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn(
self,
agent_id: str,
@ -443,6 +439,7 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
)
async def resume_agent_turn(
self,
@ -505,7 +502,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session(
self,
agent_id: str,

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here
InterleavedContentItem = register_schema(
Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
],
name="InterleavedContentItem",
)
InterleavedContentItem = Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)
ContentDelta = Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")

View file

@ -72,24 +72,22 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog"
ParamType = register_schema(
Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
ParamType = Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
name="ParamType",
)
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")
"""
# TODO: recursive definition of ParamType in these containers

View file

@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
)
DataSource = Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
@ -121,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optinally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -144,18 +144,16 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema(
Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
Field(discriminator="role"),
Message = Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
name="Message",
)
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
ResponseFormat = register_schema(
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
],
name="ResponseFormat",
)
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
# This is an internally used class

View file

@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type
class VersionInfo(BaseModel):
version: str
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = register_schema(
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig",
)
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
@ -184,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")

View file

@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ...

View file

@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus
StructuredLogPayload = register_schema(
Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
name="StructuredLogPayload",
)
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload
Event = register_schema(
Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
name="Event",
)
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type

View file

@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str
RAGQueryGeneratorConfig = register_schema(
Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"),
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
name="RAGQueryGeneratorConfig",
)
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type

View file

@ -69,7 +69,7 @@ class ToolGroup(Resource):
@json_schema_type
class ToolInvocationResult(BaseModel):
content: InterleavedContent
content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")

View file

@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
vector_db_store: VectorDBStore
vector_db_store: VectorDBStore | None = None
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion

View file

@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now(timezone.utc) > manifest.expires_on:
if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform $BUILD_PLATFORM")
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then

View file

@ -13,6 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
@ -28,6 +29,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
pass
class ShieldWithACL(Shield, ResourceWithACL):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
class ToolWithACL(Tool, ResourceWithACL):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[
Model,
Shield,
@ -41,13 +151,14 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
VectorDB,
Dataset,
Benchmark,
Tool,
ToolGroup,
ModelWithACL,
ShieldWithACL,
VectorDBWithACL,
DatasetWithACL,
ScoringFnWithACL,
BenchmarkWithACL,
ToolWithACL,
ToolGroupWithACL,
],
Field(discriminator="type"),
]

View file

@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListProvidersResponse,
ListRoutesResponse,
ProviderInfo,
RouteInfo,
VersionInfo,
)
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config

View file

@ -9,7 +9,6 @@ import inspect
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
self.endpoint_impls = initialize_endpoint_impls(self.impls)
return True
async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming(
self,
*,
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params = self._find_matching_endpoint(options.method, path)
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params = self._find_matching_endpoint(options.method, path)
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
async def gen():
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body:
return {}
func, _ = self._find_matching_endpoint(method, path)
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -7,21 +7,26 @@
import contextvars
import json
import logging
from typing import Any, ContextManager, Dict, Optional
from typing import Any, ContextManager, Dict, List, Optional
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
self.provider_data = provider_data
def __init__(
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
self.token = None
def __enter__(self):
@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
return None
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
"""Context manager that sets request provider data from headers for the duration of the context"""
def request_provider_data_context(
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")

View file

@ -19,6 +19,8 @@ from llama_stack.apis.datasets import (
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
@ -32,11 +34,22 @@ from llama_stack.apis.tools import (
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -165,6 +178,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -181,6 +199,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
@ -193,7 +218,17 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -230,7 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
@ -276,7 +311,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
if params is None:
params = {}
shield = Shield(
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
@ -330,7 +365,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
@ -358,6 +393,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
@ -378,7 +419,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None:
metadata = {}
dataset = Dataset(
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
@ -429,7 +470,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
provider_id = list(self.impls_by_provider_id.keys())[0]
benchmark = Benchmark(
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
grader_ids=grader_ids,
@ -473,7 +514,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool_def in tool_defs:
tools.append(
Tool(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
@ -498,7 +539,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
@ -511,7 +552,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)

View file

@ -5,16 +5,118 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint):
self.app = app
self.auth_endpoint = auth_endpoint
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
auth_data = {
"api_key": api_key,
"request": {
"path": path,
"headers": request_headers,
"params": params,
},
}
# Build the auth request model
auth_request = AuthRequest(
api_key=api_key,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(self.auth_endpoint, json=auth_data)
response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import inspect
import re
from typing import Dict, List
from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None
def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints
return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
@ -179,8 +183,11 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Use context manager for request provider data
with request_provider_data_context(request.headers):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware:
def __init__(self, app):
def __init__(self, app, impls):
self.app = app
self.impls = impls
async def __call__(self, scope, receive, send):
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
@ -348,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
@ -366,7 +388,7 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()
@ -412,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -5,9 +5,7 @@
# the root directory of this source tree.
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -35,7 +33,7 @@ def rag_chat_page():
)
if st.button("Create Vector Database"):
documents = [
Document(
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()

View file

@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType]
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before")
@classmethod
@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema(
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type

View file

@ -12,6 +12,7 @@
# the top-level of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@ -203,9 +204,10 @@ class ChatFormat:
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""

View file

@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.

View file

@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,

View file

@ -15,8 +15,11 @@ import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
# Extract keyword arguments
for keyword in node.keywords:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
result.append((function_name, function_args))

View file

@ -180,25 +180,29 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
async for chunk in self._run_turn(request, turn_id):
yield chunk
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools()
async with tracing.span("resume_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
async for chunk in self._run_turn(request):
yield chunk
span.set_attribute("turn_id", request.turn_id)
await self._initialize_tools()
async for chunk in self._run_turn(request):
yield chunk
async def _run_turn(
self,

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value:
return None
return AgentSessionInfo(**json.loads(value))
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id)
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id
await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)

View file

@ -35,12 +35,12 @@ class PandasDataframeDataset:
else:
return self.df.iloc[idx].to_dict()
def load(self) -> None:
async def load(self) -> None:
if self.df is not None:
return
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()
start_index = start_index or 0
@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from tqdm import tqdm
@ -20,8 +20,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@ -101,7 +101,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -215,17 +215,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return JobStatus.completed
return Job(job_id=job_id, status=JobStatus.completed)
return None
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id)
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],

View file

@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
@ -36,6 +38,8 @@ FIXED_FNS = [
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s, re.UNICODE)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
docvqa basically matches the generated answer against several allowed
choices, but we need to normalize the answer to avoid penalizing
trivial differences
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}

File diff suppressed because it is too large Load diff

View file

@ -6,12 +6,14 @@
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)

View file

@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
OTEL = "otel"
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
description="The OpenTelemetry collector endpoint URL for traces",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
otel_metric_endpoint: str = Field(
default="http://localhost:4318/v1/metrics",
description="The OpenTelemetry collector endpoint URL for metrics",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -12,6 +12,7 @@ from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_span_id, format_trace_id
class SQLiteSpanProcessor(SpanProcessor):
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
trace_id = format_trace_id(span.get_span_context().trace_id)
span_id = format_span_id(span.get_span_context().span_id)
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
parent_span_id = format_span_id(parent_context.span_id)
# Insert into traces
cursor.execute(
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
(span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},
@ -54,30 +54,21 @@ _global_lock = threading.Lock()
_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
# service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "",
}
)
@ -91,15 +82,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
if TelemetrySink.OTEL_TRACE in self.config.sinks:
span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_trace_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
endpoint=self.config.otel_metric_endpoint,
)
)
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
@ -109,7 +101,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
if TelemetrySink.OTEL in self.config.sinks:
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
@ -135,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span_id = event.span_id
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -146,7 +138,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
@ -154,6 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
@ -163,6 +156,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
@ -182,6 +176,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
@ -192,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
@ -204,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
context = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
context = trace.set_span_in_context(parent_span)
else:
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"
span = tracer.start_span(
name=event.payload.name,

View file

@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
)
_subprocess.Popen = popen_not_allowed
_subprocess.Popen = popen_not_allowed # type: ignore
import atexit as _atexit
@ -104,7 +104,7 @@ def _open_connections():
return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections
_builtins._open_connections = _open_connections # type: ignore
@_atexit.register

View file

@ -161,9 +161,9 @@ _set_seeds()\
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data]
images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
images = [Image.open(BytesIO(img)) for img in images_raw]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):

View file

@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])

View file

@ -15,6 +15,7 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import Inference
@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
return
async def insert(
self,
documents: List[RAGDocument],
@ -121,11 +129,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = [
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)

View file

@ -15,11 +15,13 @@ import faiss
import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
class FaissIndex(EmbeddingIndex):
chunk_by_index: Dict[int, str]
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
self.chunk_by_index = {}
self.chunk_by_index: dict[int, Chunk] = {}
self.kvstore = kvstore
self.bank_id = bank_id
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
self.kvstore = None
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self,
vector_db: VectorDB,
) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(
key=key,
@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return [i.vector_db for i in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
assert self.kvstore is not None
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return

View file

@ -15,9 +15,10 @@ import numpy as np
import sqlite_vec
from numpy.typing import NDArray
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
logger = logging.getLogger(__name__)
@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back to maintain consistency.
"""
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
cur = self.connection.cursor()
try:
# Start transaction
@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
if isinstance(chunk.content, str)
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
if isinstance(chunk.content, str)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
"""
def __init__(self, config, inference_api: Api.inference) -> None:
def __init__(self, config, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache: Dict[str, VectorDBWithIndex] = {}

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.eval,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]

View file

@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"transformers[accelerate]",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.prompt_guard",

View file

@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
)
@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
if not tool_calls:
return []
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call_function_arguments,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]

View file

@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
if not tool_calls:
return []
call_function_arguments = None
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call_function_arguments,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
from urllib.parse import unquote
@ -13,12 +14,15 @@ import pandas
from llama_stack.providers.utils.memory.vector_store import parse_data_url
def get_dataframe_from_uri(uri: str):
async def get_dataframe_from_uri(uri: str):
df = None
if uri.endswith(".csv"):
df = pandas.read_csv(uri)
# Moving to its own thread to avoid io from blocking the eventloop
# This isn't ideal as it moves more then just the IO to a new thread
# but it is as close as we can easly get
df = await asyncio.to_thread(pandas.read_csv, uri)
elif uri.endswith(".xlsx"):
df = pandas.read_excel(uri)
df = await asyncio.to_thread(pandas.read_excel, uri)
elif uri.startswith("data:"):
parts = parse_data_url(uri)
data = parts["data"]

View file

@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
async def impl(
content_: InterleavedContent,
) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
) -> Union[
str,
OpenAIChatCompletionContentPartParam,
List[OpenAIChatCompletionContentPartParam],
]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
@ -609,6 +613,7 @@ def convert_tool_call(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
arguments_json=tool_call.function.arguments,
)
except Exception:
return UnparseableToolCall(
@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
"multiple tool calls found in a single delta, using the first, ignoring the rest",
stacklevel=2,
)
if not enable_incremental_tool_calls:
@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
arguments_json=buffer["arguments"],
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(

View file

@ -28,6 +28,17 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
}
def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
return {
"weighted_average": sum(
result["score"] * result["weight"]
for result in scoring_results
if result["score"] is not None and result["weight"] is not None
)
/ sum(result["weight"] for result in scoring_results if result["weight"] is not None),
}
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
@ -46,6 +57,7 @@ def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.weighted_average: aggregate_weighted_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}

View file

@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
class TelemetryDatasetMixin:
"""Mixin class that provides dataset-related functionality for telemetry providers."""
datasetio_api: DatasetIO
datasetio_api: DatasetIO | None
async def save_spans_to_dataset(
self,

View file

@ -5,12 +5,11 @@
# the root directory of this source tree.
import asyncio
import base64
import contextvars
import logging
import queue
import random
import threading
import uuid
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def span_id_to_str(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id_to_str(span_id)
def generate_trace_id() -> str:
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return trace_id_to_str(trace_id)
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
@ -83,7 +115,7 @@ class TraceContext:
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(timezone.utc),
@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Protocol, TypeVar
from typing import Any, Callable, List, Optional, TypeVar
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
@ -18,13 +18,11 @@ class WebMethod:
response_examples: Optional[List[Any]] = None
method: Optional[str] = None
raw_bytes_request_body: Optional[bool] = False
# A descriptive name of the corresponding span created by tracing
descriptive_name: Optional[str] = None
class HasWebMethod(Protocol):
__webmethod__: WebMethod
T = TypeVar("T", bound=HasWebMethod) # Bound T to classes that match this protocol
T = TypeVar("T", bound=Callable[..., Any])
def webmethod(
@ -34,6 +32,7 @@ def webmethod(
request_examples: Optional[List[Any]] = None,
response_examples: Optional[List[Any]] = None,
raw_bytes_request_body: Optional[bool] = False,
descriptive_name: Optional[str] = None,
) -> Callable[[T], T]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@ -44,15 +43,16 @@ def webmethod(
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
"""
def wrap(cls: T) -> T:
cls.__webmethod__ = WebMethod(
def wrap(func: T) -> T:
func.__webmethod__ = WebMethod( # type: ignore
route=route,
method=method,
public=public or False,
request_examples=request_examples,
response_examples=response_examples,
raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name,
)
return cls
return func
return wrap

View file

@ -17,6 +17,7 @@ import enum
import functools
import inspect
import json
import types
import typing
import uuid
from copy import deepcopy
@ -455,7 +456,7 @@ class JsonSchemaGenerator:
"maxItems": len(args),
"prefixItems": [self.type_to_schema(member_type) for member_type in args],
}
elif origin_type is Union:
elif origin_type in (Union, types.UnionType):
discriminator = None
if typing.get_origin(data_type) is Annotated:
discriminator = typing.get_args(data_type)[1].discriminator

View file

@ -78,7 +78,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
},

View file

@ -47,9 +47,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \

View file

@ -37,7 +37,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
datasetio:

View file

@ -102,7 +102,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"CEREBRAS_API_KEY": (

View file

@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
@ -55,6 +56,6 @@ docker run \
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port 8321 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

View file

@ -58,7 +58,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
tool_runtime:

View file

@ -108,7 +108,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (

View file

@ -40,7 +40,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
datasetio:

View file

@ -43,6 +43,7 @@ export CUDA_VISIBLE_DEVICES=0
export LLAMA_STACK_PORT=8321
docker run --rm -it \
--pull always \
--network host \
-v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \
@ -66,6 +67,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \
--pull always \
--network host \
-v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \
@ -108,6 +110,7 @@ This method allows you to get started quickly without having to build the distri
```bash
docker run -it \
--pull always \
--network host \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
@ -135,6 +138,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -43,7 +43,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
datasetio:

View file

@ -39,7 +39,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
datasetio:

View file

@ -184,7 +184,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (

View file

@ -69,7 +69,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
datasetio:

View file

@ -28,7 +28,11 @@ providers:
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
scoring:
- provider_id: basic
provider_type: inline::basic
@ -40,7 +44,11 @@ providers:
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config: {}
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -58,7 +66,7 @@ providers:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/agents_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -70,7 +78,7 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/faiss_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
@ -82,7 +90,7 @@ providers:
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/registry.db
models: []
shields: []
vector_dbs: []

View file

@ -49,9 +49,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \

View file

@ -160,7 +160,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (

View file

@ -48,7 +48,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
datasetio:

View file

@ -43,7 +43,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
datasetio:

View file

@ -49,9 +49,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \

View file

@ -95,7 +95,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"GROQ_API_KEY": (

View file

@ -43,7 +43,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
datasetio:

View file

@ -125,7 +125,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"HF_API_TOKEN": (

View file

@ -48,7 +48,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
datasetio:

View file

@ -43,7 +43,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
datasetio:

View file

@ -126,7 +126,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"HF_API_TOKEN": (

View file

@ -48,7 +48,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
datasetio:

View file

@ -43,7 +43,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
datasetio:

View file

@ -65,9 +65,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
@ -80,6 +81,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
@ -95,7 +97,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
llama stack build --template {{ name }} --image-type conda
llama stack run distributions/{{ name }}/run.yaml \
--port 5001 \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
@ -103,7 +105,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
llama stack run distributions/{{ name }}/run-with-safety.yaml \
--port 5001 \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

View file

@ -132,7 +132,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (

View file

@ -50,7 +50,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
datasetio:

View file

@ -44,7 +44,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
datasetio:

View file

@ -67,9 +67,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \

View file

@ -98,7 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"8321",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (

View file

@ -46,7 +46,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
datasetio:

View file

@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
@ -55,7 +56,7 @@ docker run \
```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -46,7 +46,6 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
datasetio:

Some files were not shown because too many files have changed in this diff Show more