Merge branch 'main' of https://github.com/meta-llama/llama-stack into add_nemo_customizer

This commit is contained in:
Ubuntu 2025-03-20 09:34:19 +00:00
commit f534b4c2ea
571 changed files with 229651 additions and 12956 deletions

View file

@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: Optional[datetime] = None
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
class StepType(Enum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context from vector DBs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
@ -66,6 +94,11 @@ class StepType(Enum):
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation]
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -138,17 +189,15 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
]
register_schema(AgentToolGroup, name="AgentTool")
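
This commit consistently replaces the old pattern, where register_schema returned the union, with a plain Annotated type alias followed by an explicit register_schema call. A minimal sketch of the new pattern with hypothetical FooItem/BarItem models (register_schema only records the alias for schema generation; the alias itself stays an ordinary typing construct):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field

from llama_stack.schema_utils import register_schema


class FooItem(BaseModel):
    type: Literal["foo"] = "foo"
    text: str


class BarItem(BaseModel):
    type: Literal["bar"] = "bar"
    url: str


# Declare the discriminated union as a normal type alias ...
MyItem = Annotated[Union[FooItem, BarItem], Field(discriminator="type")]
# ... then register it by name for OpenAPI / JSON schema generation.
register_schema(MyItem, name="MyItem")
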
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
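
The sampling_params default also changes from an eagerly constructed SamplingParams() to Field(default_factory=SamplingParams), so the default is built per instance at model-creation time rather than once at class-definition (import) time. A small, self-contained illustration (SamplingParams here is a stand-in, not the real class):

from pydantic import BaseModel, Field


class SamplingParams(BaseModel):  # stand-in for the real model
    temperature: float = 1.0


class AgentConfigExample(BaseModel):
    # Constructed once, when this class body is evaluated at import time.
    eager_default: SamplingParams = SamplingParams()
    # Constructed lazily, one fresh object per AgentConfigExample instance.
    lazy_default: SamplingParams = Field(default_factory=SamplingParams)


a, b = AgentConfigExample(), AgentConfigExample()
assert a.lazy_default is not b.lazy_default  # default_factory yields distinct objects
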
@ -183,6 +232,23 @@ class AgentConfig(AgentConfigCommon):
response_format: Optional[ResponseFormat] = None
@json_schema_type
class Agent(BaseModel):
agent_id: str
agent_config: AgentConfig
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
@ -244,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
name="AgentTurnResponseEventPayload",
)
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
@ -296,16 +360,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponseMessage]
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
@ -338,7 +399,13 @@ class Agents(Protocol):
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn(
@ -355,8 +422,19 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with; these are used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with; this overrides the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
"""
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -367,7 +445,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
@ -392,7 +470,15 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn: ...
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
@ -404,14 +490,30 @@ class Agents(Protocol):
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse: ...
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session(
@ -419,17 +521,64 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent session by its ID.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent(
self,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent by its ID.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
"""List all agents.
:returns: A ListAgentsResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET")
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent object describing the agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
"""
...

View file

@ -40,7 +40,7 @@ class BatchInference(Protocol):
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@ -50,7 +50,7 @@ class BatchInference(Protocol):
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here
InterleavedContentItem = register_schema(
Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
],
name="InterleavedContentItem",
)
InterleavedContentItem = Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)
ContentDelta = Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")

View file

@ -72,24 +72,22 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog"
ParamType = register_schema(
Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
ParamType = Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
name="ParamType",
)
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")
"""
# TODO: recursive definition of ParamType in these containers

View file

@ -13,11 +13,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
next_page_token: Optional[str] = None
class IterrowsResponse(BaseModel):
"""
A paginated list of rows from a dataset.
:param data: The rows in the current page.
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
"""
data: List[Dict[str, Any]]
next_start_index: Optional[int] = None
class DatasetStore(Protocol):
@ -29,14 +34,21 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
@webmethod(route="/datasetio/rows", method="POST")
:param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
"""
...
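
A sketch of paging through a dataset with the new cursor-based API; datasetio stands for any DatasetIO implementation and the page size is arbitrary:

async def read_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> list:
    rows: list = []
    start_index = 0
    while True:
        page = await datasetio.iterrows(dataset_id, start_index=start_index, limit=page_size)
        rows.extend(page.data)
        if page.next_start_index is None:  # no more rows in the dataset
            break
        start_index = page.next_start_index
    return rows
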
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -4,19 +4,100 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(str, Enum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):
dataset_schema: Dict[str, ParamType]
url: URL
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
@ -38,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):
@ -50,19 +129,75 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
dataset_id: Optional[str] = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
"""
...
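
A sketch of registering a dataset under the new purpose/source API, mirroring one of the docstring examples; datasets stands for any Datasets implementation, the dataset_id is illustrative, and the import path is assumed:

from llama_stack.apis.datasets import DatasetPurpose, URIDataSource  # import path assumed


async def register_qa_dataset(datasets):
    return await datasets.register_dataset(
        purpose=DatasetPurpose.eval_question_answer,
        source=URIDataSource(uri="https://mywebsite.com/mydata.jsonl"),
        metadata={"description": "My dataset"},
        dataset_id="my-qa-dataset",  # illustrative; omit to let the stack generate one
    )
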
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
) -> Dataset: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...

View file

@ -5,12 +5,16 @@
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Api(Enum):
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
@ -33,3 +37,20 @@ class Api(Enum):
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: Optional[str] = None
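
For illustration, a 404-style payload built from this model might look like the following (all values are made up):

not_found = Error(
    status=404,
    title="Dataset not found",
    detail="Dataset 'my-qa-dataset' is not registered with this stack instance.",
    instance="/datasets/my-qa-dataset",
)
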

View file

@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
@ -27,18 +34,28 @@ class ModelCandidate(BaseModel):
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = register_schema(
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
name="EvalCandidate",
)
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
@ -53,18 +70,32 @@ class BenchmarkConfig(BaseModel):
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job: ...
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
@ -72,14 +103,41 @@ class Eval(Protocol):
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse: ...
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluation job.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info(
self,
upload_id: str,
) -> Optional[FileUploadResponse]:
) -> FileUploadResponse:
"""
Returns information about an existing upload session

View file

@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
:param role: Must be "tool" to identify this as a tool response
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was called
:param content: The response content from the tool
"""
role: Literal["tool"] = "tool"
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
@ -146,18 +144,16 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema(
Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
Field(discriminator="role"),
Message = Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
name="Message",
)
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
@ -265,27 +261,25 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
ResponseFormat = register_schema(
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
],
name="ResponseFormat",
)
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
# This is an internally used class
class CompletionRequest(BaseModel):
model: str
content: InterleavedContent
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
@ -299,7 +293,7 @@ class CompletionResponse(BaseModel):
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
@ -357,7 +351,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@ -368,7 +362,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -378,7 +372,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
@ -444,7 +438,7 @@ class Inference(Protocol):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -467,7 +461,7 @@ class Inference(Protocol):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

View file

@ -11,13 +11,6 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
@ -36,19 +29,12 @@ class VersionInfo(BaseModel):
version: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...

View file

@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
) -> Model: ...
@webmethod(route="/models", method="POST")
async def register_model(

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Any, Dict, List, Literal, Optional, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
name="AlgorithmConfig",
)
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
@ -184,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")
@ -202,10 +200,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import * # noqa: F401 F403

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
config: Dict[str, Any]
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...

View file

@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
@ -55,4 +68,11 @@ class Scoring(Protocol):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ...
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...
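
An illustrative call; the scoring function identifier and row fields below are assumptions, not part of this diff:

async def score_example(scoring):
    response = await scoring.score(
        input_rows=[
            {
                "input_query": "What is the capital of France?",
                "expected_answer": "Paris",
                "generated_answer": "Paris",
            }
        ],
        scoring_functions={"basic::equality": None},  # id assumed; None means default params
    )
    for fn_name, result in response.results.items():
        print(fn_name, result.aggregated_results)
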

View file

@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
)
ScoringFnParams = register_schema(
Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
name="ScoringFnParams",
)
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
@ -135,7 +134,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(

View file

@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(

View file

@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
metric: str
value: Union[int, float]
unit: Optional[str] = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
metrics: Optional[List[MetricInResponse]] = None
@json_schema_type
@ -139,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus
StructuredLogPayload = register_schema(
Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
name="StructuredLogPayload",
)
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
@ -157,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload
Event = register_schema(
Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
name="Event",
)
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type

View file

@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str
RAGQueryGeneratorConfig = register_schema(
Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"),
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
name="RAGQueryGeneratorConfig",
)
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type

View file

@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db(
self,
vector_db_id: str,
) -> Optional[VectorDB]: ...
) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(

View file

@ -10,7 +10,7 @@ import json
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
@ -343,7 +343,7 @@ def _hf_download(
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
except Exception as e:
parser.error(e)
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now() > manifest.expires_on:
if datetime.now(timezone.utc) > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
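
The switch to datetime.now(timezone.utc) matters because Python refuses to compare naive and timezone-aware datetimes; a short illustration (assuming manifest.expires_on is timezone-aware):

from datetime import datetime, timezone

expires_on = datetime(2025, 1, 1, tzinfo=timezone.utc)  # aware, like manifest.expires_on
print(datetime.now(timezone.utc) > expires_on)          # fine: both aware
# datetime.now() > expires_on                           # TypeError: can't compare naive and aware
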

View file

@ -9,6 +9,7 @@ import argparse
from .download import Download
from .model import ModelParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload
@ -20,6 +21,7 @@ class LlamaCLIParser:
prog="llama",
description="Welcome to the Llama CLI",
add_help=True,
formatter_class=argparse.RawTextHelpFormatter,
)
# Default command is to print help
@ -33,6 +35,8 @@ class LlamaCLIParser:
Download.create(subparsers)
VerifyDownload.create(subparsers)
print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()

View file

@ -7,8 +7,6 @@
import argparse
import json
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
@ -52,11 +50,12 @@ class ModelDescribe(Subcommand):
)
return
headers = [
"Model",
model.descriptor(),
]
rows = [
(
colored("Model", "white", attrs=["bold"]),
colored(model.descriptor(), "white", attrs=["bold"]),
),
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
("Description", model.description),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
@ -65,7 +64,7 @@ class ModelDescribe(Subcommand):
]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.dict()
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(
@ -77,5 +76,6 @@ class ModelDescribe(Subcommand):
print_table(
rows,
headers,
separate_rows=True,
)

View file

@ -12,6 +12,7 @@ from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
@ -24,6 +25,7 @@ class ModelParser(Subcommand):
"model",
prog="llama model",
description="Work with llama models",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
@ -37,3 +39,5 @@ class ModelParser(Subcommand):
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
ModelRemove.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@ -7,10 +7,14 @@
import argparse
import textwrap
from io import StringIO
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent
class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)"""
@ -37,8 +41,14 @@ class ModelPromptFormat(Subcommand):
"-m",
"--model-name",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
"(Run `llama model list` to see a list of valid model names)",
)
self.parser.add_argument(
"-l",
"--list",
action="store_true",
help="List all available models",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
@ -48,18 +58,42 @@ class ModelPromptFormat(Subcommand):
supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
model_list = [m.value for m in supported_model_ids]
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
if model_id not in supported_model_ids:
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
self.parser.error(
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()

View file

@ -38,8 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -65,8 +65,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.image_type == "venv":
current_venv = os.environ.get("VIRTUAL_ENV")
image_name = args.image_name or current_venv
if not image_name and in_notebook():
image_name = "__system__"
elif args.image_type == "conda":
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
@ -143,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in available_providers,
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options",
),
)
@ -172,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if build_config.image_type == ImageType.container.value and not args.image_name:
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
@ -215,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_with_pty(run_args)
run_command(run_args)
def _generate_run_config(
@ -228,7 +226,7 @@ def _generate_run_config(
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name,
apis=apis,
providers={},
@ -250,7 +248,7 @@ def _generate_run_config(
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
@ -281,16 +279,18 @@ def _run_stack_build_command_from_build_config(
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> str:
if build_config.image_type == ImageType.container.value:
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:
image_name = f"distribution-{template_name}"
else:
if not image_name:
raise ValueError("Please specify an image name when building a container image without a template")
elif build_config.image_type == ImageType.conda.value:
elif build_config.image_type == LlamaStackImageType.CONDA.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")
elif build_config.image_type == ImageType.venv.value:
elif build_config.image_type == LlamaStackImageType.VENV.value:
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
image_name = "__system__"
if not image_name:
raise ValueError("Please specify an image name when building a venv image")

View file

@ -16,7 +16,7 @@ class StackBuild(Subcommand):
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)
@ -26,7 +26,7 @@ class StackBuild(Subcommand):
"--config",
type=str,
default=None,
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
)
self.parser.add_argument(

View file

@ -9,9 +9,12 @@ import os
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
@ -20,7 +23,7 @@ class StackRun(Subcommand):
"run",
prog="llama stack run",
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_run_cmd)
@ -34,12 +37,13 @@ class StackRun(Subcommand):
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 8321",
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
)
self.parser.add_argument(
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current conda environment",
)
self.parser.add_argument(
@ -75,19 +79,10 @@ class StackRun(Subcommand):
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import (
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
if not args.config:
self.parser.error("Must specify a config file to run")
return
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
@ -99,14 +94,6 @@ class StackRun(Subcommand):
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to container dir
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
@ -115,11 +102,23 @@ class StackRun(Subcommand):
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
return
print(f"Using run configuration: {config_file}")
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file} is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
@ -129,20 +128,12 @@ class StackRun(Subcommand):
for env_var in args.env:
if "=" not in env_var:
cprint(
f"Environment variable '{env_var}' must be in KEY=VALUE format",
color="red",
)
return
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
key, value = env_var.split("=", 1) # split on first = only
if not key:
cprint(
f"Environment variable '{env_var}' has empty key",
color="red",
)
return
self.parser.error(f"Environment variable '{env_var}' has empty key")
run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_with_pty(run_args)
run_command(run_args)

View file

@ -7,6 +7,7 @@
import argparse
from importlib.metadata import version
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
@ -22,6 +23,7 @@ class StackParser(Subcommand):
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.add_argument(
@ -39,3 +41,5 @@ class StackParser(Subcommand):
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def print_subcommand_description(parser, subparsers):
"""Print descriptions of subcommands."""
description_text = ""
for name, subcommand in subparsers.choices.items():
description = subcommand.description
description_text += f" {name:<21} {description}\n"
parser.epilog = description_text
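
As wired into the llama, llama model, and llama stack parsers above, this helper fills the parser epilog with one line per subcommand; a condensed usage sketch with illustrative subcommands:

import argparse

from llama_stack.cli.stack.utils import print_subcommand_description

parser = argparse.ArgumentParser(
    prog="llama",
    formatter_class=argparse.RawTextHelpFormatter,  # preserves the epilog's newlines
)
subparsers = parser.add_subparsers(title="subcommands")
subparsers.add_parser("build", description="Build a Llama stack container")
subparsers.add_parser("run", description="Start the server for a Llama Stack Distribution")

print_subcommand_description(parser, subparsers)
print(parser.format_help())  # the epilog now lists each subcommand with its description
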

View file

@ -1,127 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
"""
version: {version}
image_name: foo
apis_to_serve: []
built_at: {built_at}
providers:
inference:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
""".format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
)
@pytest.fixture
def old_config():
return yaml.safe_load(
"""
image_name: foo
built_at: {built_at}
apis_to_serve: []
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 11434
routing_key: Llama3.2-1B-Instruct
- provider_type: inline::meta-reference
config:
model: Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct
safety:
- routing_key: ["shield1", "shield2"]
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- routing_key: vector
provider_type: inline::meta-reference
config: {{}}
api_providers:
telemetry:
provider_type: noop
config: {{}}
""".format(built_at=datetime.now().isoformat())
)
@pytest.fixture
def invalid_config():
return yaml.safe_load(
"""
routing_table: {}
api_providers: {}
"""
)
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
result = parse_and_maybe_upgrade_config(up_to_date_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert "inference" in result.providers
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
"remote::ollama-00",
"meta-reference-01",
}
ollama = inference_providers[0]
assert ollama.provider_type == "remote::ollama"
assert ollama.config["port"] == 11434
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(ValueError):
parse_and_maybe_upgrade_config(invalid_config)

View file

@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj: The resource object to check access for
user_attributes: The attributes of the current user, keyed by category (e.g. "roles", "teams")
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
if not obj_attributes:
return True
# Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
return True
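The matching rules described in the docstring can be illustrated with a small standalone sketch; plain dicts stand in for AccessAttributes and the resource object, so this is not the library function itself, only the rule it implements (ALL categories must match, ANY shared value within a category is enough):

# Standalone sketch of the attribute-matching rules (illustrative only).
from typing import Dict, List, Optional

def matches(required: Dict[str, List[str]], user: Optional[Dict[str, List[str]]]) -> bool:
    if not required:
        return True   # no access_attributes -> visible to everyone
    if not user:
        return False  # protected object, user with no attributes -> denied
    for category, required_values in required.items():
        user_values = user.get(category, [])
        if not any(v in user_values for v in required_values):
            return False  # missing category, or no overlapping value
    return True

assert matches(
    {"roles": ["admin", "data-scientist"], "teams": ["ml-team"]},
    {"roles": ["data-scientist"], "teams": ["ml-team", "infra-team"]},
)
assert not matches({"roles": ["admin"]}, {"roles": ["engineer"]})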

View file

@ -6,7 +6,6 @@
import importlib.resources
import logging
import sys
from pathlib import Path
from typing import Dict, List
@ -15,9 +14,8 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
@ -96,18 +94,16 @@ def build_image(
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.container.value:
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
template_or_config,
image_name,
container_base,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.container.value),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
@ -115,7 +111,7 @@ def build_image(
str(build_file_path),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,
@ -126,11 +122,7 @@ def build_image(
if special_deps:
args.append("#".join(special_deps))
is_terminal = sys.stdin.isatty()
if is_terminal:
return_code = run_with_pty(args)
else:
return_code = run_command(args)
return_code = run_command(args)
if return_code != 0:
log.error(

View file

@ -6,8 +6,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
if [ "$#" -lt 3 ]; then
@ -52,7 +52,7 @@ ensure_conda_env_python310() {
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
@ -87,8 +87,6 @@ ensure_conda_env_python310() {
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION \
llama-stack-client==$TEST_PYPI_VERSION \
llama-stack==$TEST_PYPI_VERSION \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
@ -111,22 +109,21 @@ ensure_conda_env_python310() {
else
PYPI_VERSION="${PYPI_VERSION:-}"
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
uv pip install --no-cache-dir $SPEC_VERSION
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
uv pip uninstall llama-models
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
# Install pip dependencies

View file

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
@ -6,7 +6,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
@ -20,35 +19,38 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
if [ "$#" -lt 6 ]; then
if [ "$#" -lt 4 ]; then
# This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail
template_or_config="$1"
image_name="$2"
container_base="$3"
build_file_path="$4"
host_build_dir="$5"
pip_dependencies="$6"
special_pip_deps="${7:-}"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() {
local input
output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
@ -58,15 +60,21 @@ add_to_container() {
fi
}
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
@ -107,7 +115,6 @@ EOF
fi
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
client_mount="/app/llama-stack-client-source"
install_local_package() {
@ -131,10 +138,6 @@ EOF
}
if [ -n "$LLAMA_MODELS_DIR" ]; then
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
@ -150,12 +153,12 @@ EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
llama-stack==$TEST_PYPI_VERSION
EOF
else
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
@ -165,6 +168,11 @@ EOF
fi
fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
cat $TEMP_DIR/Containerfile
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"
mounts=""
# Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi
fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
CLI_ARGS+=("--security-opt" "label=disable")
fi
# Set version tag based on PyPI version
@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
version_tag="dev"
else
URL="https://pypi.org/pypi/llama-stack/json"
@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
PLATFORM="--platform $BUILD_PLATFORM"
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
PLATFORM="--platform linux/arm64"
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
PLATFORM="--platform linux/amd64"
CLI_ARGS+=("--platform" "linux/amd64")
else
echo "Unsupported architecture: $ARCH"
exit 1
@ -238,8 +248,12 @@ fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
$CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"."
# clean up tmp/configs
set +x

View file

@ -9,8 +9,8 @@
# TODO: combine this with build_conda_env.sh since it is almost identical
# the only difference is that we don't do any conda-specific setup
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
@ -21,13 +21,13 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
@ -95,7 +95,7 @@ run() {
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
@ -120,15 +120,14 @@ run() {
uv pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
uv pip uninstall llama-models
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
# Install pip dependencies

View file

@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provi
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
config=cfg.model_dump(),
)
@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -1,47 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
set -euo pipefail
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <container name> <build file path>"
exit 1
fi
container_image="$1"
host_build_dir="$2"
container_build_dir="/app/builds"
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
--entrypoint "/usr/local/bin/llama" \
-v $host_build_dir:$container_build_dir \
$mounts \
$container_image \
stack configure ./llamastack-build.yaml --output-dir $container_build_dir

View file

@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
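A tiny sketch of why the access check operates on model_dump(exclude_none=True): unset categories simply drop out of the dump and therefore impose no requirement. The model below only mirrors the shape of AccessAttributes and is illustrative, not the class from this diff:

# Illustrative pydantic model showing exclude_none behaviour.
from typing import List, Optional
from pydantic import BaseModel

class Attrs(BaseModel):
    roles: Optional[List[str]] = None
    teams: Optional[List[str]] = None

print(Attrs(roles=["admin"]).model_dump(exclude_none=True))  # {'roles': ['admin']}
print(Attrs().model_dump(exclude_none=True))                 # {} -> no restriction at all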
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
pass
class ShieldWithACL(Shield, ResourceWithACL):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
class ToolWithACL(Tool, ResourceWithACL):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[
Model,
Shield,
@ -45,14 +155,14 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
VectorDB,
Dataset,
ScoringFn,
Benchmark,
Tool,
ToolGroup,
ModelWithACL,
ShieldWithACL,
VectorDBWithACL,
DatasetWithACL,
ScoringFnWithACL,
BenchmarkWithACL,
ToolWithACL,
ToolGroupWithACL,
],
Field(discriminator="type"),
]
@ -117,6 +227,21 @@ class Provider(BaseModel):
config: Dict[str, Any]
class LoggingConfig(BaseModel):
category_levels: Dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
class AuthenticationConfig(BaseModel):
endpoint: str = Field(
...,
description="Endpoint URL to validate authentication tokens",
)
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -132,6 +257,10 @@ class ServerConfig(BaseModel):
default=None,
description="Path to TLS key file for HTTPS",
)
auth: Optional[AuthenticationConfig] = Field(
default=None,
description="Authentication configuration for the server",
)
class StackRunConfig(BaseModel):
@ -176,6 +305,8 @@ a default SQLite store will be used.""",
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",

View file

@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
return [v for v in Api]
return list(Api)
class AutoRoutedApiInfo(BaseModel):
@ -55,8 +55,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:

View file

@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListProvidersResponse,
ListRoutesResponse,
ProviderInfo,
RouteInfo,
VersionInfo,
)
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config

View file

@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -41,8 +44,10 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -104,7 +109,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
logger.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
return value
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LlamaStackAsLibraryClient(LlamaStackClient):
@ -160,6 +165,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator()
@ -262,21 +270,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
if self.provider_data:
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
@ -324,6 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace()
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
@ -335,7 +348,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url,
params=options.params,
headers=options.headers or {},
json=options.json_data,
json=convert_pydantic_to_json_value(body),
),
)
response = APIResponse(
@ -373,9 +386,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},
@ -384,7 +399,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url,
params=options.params,
headers=options.headers or {},
json=options.json_data,
json=convert_pydantic_to_json_value(body),
),
)

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
ret = []
for api, providers in safe_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")

View file

@ -4,16 +4,40 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
import logging
import threading
from typing import Any, Dict
from typing import Any, ContextManager, Dict, List, Optional
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data"""
def __init__(
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
@ -26,7 +50,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
val = PROVIDER_DATA_VAR.get()
if not val:
return None
@ -36,25 +60,42 @@ class NeedsRequestProviderData:
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def set_request_provider_data(headers: Dict[str, str]):
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return
return None
try:
val = json.loads(val)
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!", val)
return
log.error("Provider data not encoded as a JSON object!")
return None
_THREAD_LOCAL.provider_data_header_value = val
def request_provider_data_context(
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")

View file

@ -5,8 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
@ -17,6 +16,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -35,6 +35,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
@ -50,7 +51,7 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
@ -104,60 +106,43 @@ class ProviderWithSpec(Provider):
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
Resolves provider implementations by:
1. Validating and organizing providers.
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error, "red", attrs=["bold"])
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
providers_with_specs[info.routing_table_api.value] = {
specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
@ -167,12 +152,12 @@ async def resolve_impls(
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
deps__=[f"inner-{info.router_api.value}"],
),
)
}
providers_with_specs[info.router_api.value] = {
specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
@ -182,12 +167,68 @@ async def resolve_impls(
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
),
)
}
return specs
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
# Append built-in "inspect" provider
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
@ -195,28 +236,51 @@ async def resolve_impls(
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
deps__=[x.value for x in apis],
),
),
)
)
log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -227,14 +291,9 @@ async def resolve_impls(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
@ -245,7 +304,7 @@ async def resolve_impls(
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
@ -261,8 +320,8 @@ def topological_sort(
stack.append(api_str)
visited = set()
stack = []
visited: Set[str] = set()
stack: List[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
@ -272,13 +331,14 @@ def topological_sort(
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
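The ordering above is a plain depth-first topological sort over API dependencies; a minimal standalone illustration with made-up API names and a simplified graph of name -> dependencies:

# Dependencies are visited before the node itself, so they appear earlier
# in the resulting order (illustrative sketch, not the resolver code).
from typing import Dict, List, Set

def topo(graph: Dict[str, List[str]]) -> List[str]:
    visited: Set[str] = set()
    order: List[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in graph.get(node, []):
            if dep not in visited:
                dfs(dep)
        order.append(node)  # appended only after all deps, so deps come first

    for node in graph:
        if node not in visited:
            dfs(node)
    return order

print(topo({"agents": ["inference", "safety"], "safety": ["inference"], "inference": []}))
# -> ['inference', 'safety', 'agents']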
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[str, Any],
deps: Dict[Api, Any],
inner_impls: Dict[str, Any],
dist_registry: DistributionRegistry,
):
@ -286,8 +346,10 @@ async def instantiate_provider(
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class)
@ -350,7 +412,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -4,14 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
@ -20,6 +22,10 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -27,13 +33,14 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@ -42,6 +49,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -51,8 +59,13 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
@ -62,12 +75,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -78,6 +94,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -92,6 +109,9 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
@ -100,6 +120,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -109,13 +130,21 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
pass
async def register_model(
@ -126,13 +155,85 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
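A rough sketch of the per-request accounting this produces, with plain dicts standing in for MetricEvent; the field names below are illustrative:

# Three records per request: prompt, completion, and their total.
import time
from typing import Dict, List

def token_metrics(prompt_tokens: int, completion_tokens: int) -> List[Dict]:
    total = prompt_tokens + completion_tokens
    return [
        {"metric": name, "value": value, "unit": "tokens", "timestamp": time.time()}
        for name, value in [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total),
        ]
    ]

print(token_metrics(120, 35))  # one record per metric, tagged with a timestamp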
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
@ -140,7 +241,12 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -159,8 +265,6 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []
@ -187,20 +291,67 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.chat_completion(**params)
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -215,10 +366,41 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.completion(**params)
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings(
self,
@ -228,6 +410,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -247,12 +430,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -262,6 +448,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -270,6 +457,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -282,29 +470,51 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -316,12 +526,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -330,6 +543,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -350,6 +564,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -367,22 +582,26 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
task_config=task_config,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
@ -390,13 +609,14 @@ class EvalRouter(Eval):
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
benchmark_config=benchmark_config,
)
async def job_status(
@ -404,6 +624,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -411,6 +632,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -421,6 +643,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -433,6 +656,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -441,7 +665,8 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl("query_from_memory").query(
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -451,6 +676,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
@ -459,6 +687,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT: this should be kept in sync with "get_all_api_endpoints()"
@ -467,12 +696,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -481,4 +713,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@ -5,6 +5,7 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
@ -12,7 +13,16 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
@ -31,11 +41,22 @@ from llama_stack.apis.tools import (
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -176,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj, get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -192,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
@ -204,15 +237,24 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
@ -238,7 +280,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
@ -259,8 +301,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
@ -281,7 +326,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
if params is None:
params = {}
shield = Shield(
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
@ -295,8 +340,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
return await self.get_object_by_identifier("vector_db", vector_db_id)
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
@ -309,23 +357,17 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
# If provider_id not specified, pick an available vector_io provider (warn when more than one exists)
if len(self.impls_by_provider_id) == 1:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
if embedding_model == "all-MiniLM-L6-v2":
raise ValueError(
"Embeddings are now served via Inference providers. "
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
)
else:
raise ValueError(f"Model {embedding_model} not found")
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
@ -338,7 +380,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
@ -353,39 +395,56 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
if provider_dataset_id is None:
provider_dataset_id = dataset_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
dataset_id: Optional[str] = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = Dataset(
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
dataset_schema=dataset_schema,
url=url,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
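For illustration, a call against this new signature might look like the sketch below (inside an async context); the `datasets_table` variable and the `DatasetPurpose` member are assumptions for the example, not taken from this diff.

```python
# Hypothetical usage sketch: the provider is inferred from the source (a "huggingface..." URI
# routes to the huggingface provider, anything else to localfs) and a "dataset-<uuid>"
# identifier is generated when dataset_id is omitted.
dataset = await datasets_table.register_dataset(   # datasets_table: an initialized DatasetsRoutingTable
    purpose=DatasetPurpose.eval_messages_answer,   # assumed enum member; use whichever purpose applies
    source=URIDataSource(uri="huggingface://datasets/llamastack/evals"),
)
```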
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
@ -398,8 +457,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
@ -419,7 +481,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFn(
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
@ -435,8 +497,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
return await self.get_object_by_identifier("benchmark", benchmark_id)
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
@ -458,7 +523,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = Benchmark(
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
@ -480,7 +545,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
@ -498,7 +566,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool_def in tool_defs:
tools.append(
Tool(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
@ -523,7 +591,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
@ -536,7 +604,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)


@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
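As a small sketch of how this model is consumed by the middleware below (attribute values are placeholders):

```python
# Hypothetical auth endpoint reply parsed the same way the middleware parses it.
raw = {
    "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]},
    "message": "ok",
}
auth_response = AuthResponse(**raw)
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
# user_attributes -> {"roles": ["admin"], "teams": ["ml-team"]}
```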
class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint):
self.app = app
self.auth_endpoint = auth_endpoint
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
api_key = auth_header.split("Bearer ", 1)[1]
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=api_key,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
await send(
{
"type": "http.response.start",
"status": 401,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
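A minimal wiring sketch, mirroring how the server attaches this middleware when an auth endpoint is configured (the endpoint URL is a placeholder):

```python
# Attach the middleware to a FastAPI app; every HTTP request is then validated
# against the external auth endpoint before reaching the route handlers.
from fastapi import FastAPI

app = FastAPI()
app.add_middleware(AuthenticationMiddleware, auth_endpoint="https://auth.example.com/validate")
```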


@ -6,12 +6,9 @@
import argparse
import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
import traceback
import warnings
@ -26,12 +23,14 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
@ -39,23 +38,26 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -117,78 +119,32 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
def handle_signal(app, signum, _) -> None:
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
Handle incoming signals and initiate a graceful shutdown of the application.
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
logger.info("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
logger.info("Shutting down")
await shutdown(app)
def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -204,15 +160,14 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
event_gen = await event_gen
async for item in event_gen:
async for item in await event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
logger.info("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -224,18 +179,25 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logger.exception(f"Error executing endpoint {route=} {method=}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -264,7 +226,7 @@ class TracingMiddleware:
self.app = app
async def __call__(self, scope, receive, send):
path = scope["path"]
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send)
@ -348,52 +310,66 @@ def main():
args = parser.parse_args()
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
print(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
print(f"Error: {str(e)}")
sys.exit(1)
log_line = ""
if args.yaml_config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
log_line = f"Using config file: {config_file}"
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --yaml-config or --template must be provided")
logger_config = None
with open(config_file, "r") as fp:
config = replace_env_vars(yaml.safe_load(fp))
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**config)
print("Run configuration:")
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
print(yaml.dump(safe_config, indent=2))
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth and config.server.auth.endpoint:
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()
@ -409,6 +385,7 @@ def main():
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)
@ -432,15 +409,10 @@ def main():
)
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
logger.debug(f"serving APIs: {apis_to_serve}")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls
@ -462,15 +434,17 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
}
if ssl_config:
uvicorn_config.update(ssl_config)


@ -5,13 +5,12 @@
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
@ -24,6 +23,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -33,16 +33,19 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,
@ -101,12 +104,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
@ -155,18 +156,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result
elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
# Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
def get_env_var(match):
env_var = match.group(1)
default_val = match.group(2)
operator = match.group(2) # ':' for default, '+' for conditional
value_expr = match.group(3)
value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
env_value = os.environ.get(env_var)
if operator == ":": # Default value syntax: ${env.FOO:default}
if not env_value:
if value_expr is None:
raise EnvVarError(env_var, path)
else:
value = value_expr
else:
value = default_val
value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
if env_value:
value = value_expr
else:
# If env var is not set, return empty string for the conditional case
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value
# expand "~" in the value
return os.path.expanduser(value)
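A short illustration of the two operators; the environment variable names below are placeholders:

```python
# Sketch of the substitution behaviour implemented above.
import os

from llama_stack.distribution.stack import replace_env_vars

os.environ["LLAMA_STACK_PORT"] = "8321"
os.environ["ENABLE_FLAG"] = "1"
print(replace_env_vars("${env.LLAMA_STACK_PORT:5000}"))  # -> "8321" (set, default ignored)
print(replace_env_vars("${env.MISSING_VAR:5000}"))        # -> "5000" (unset, default used)
print(replace_env_vars("${env.MISSING_VAR+--flag}"))      # -> ""     (unset, conditional collapses to empty)
print(replace_env_vars("${env.ENABLE_FLAG+--flag}"))      # -> "--flag" (set, conditional value emitted)
```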
@ -215,3 +232,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config
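A short usage sketch of this helper, reusing the provider names from the docstring above:

```python
# Build an ad-hoc run config pairing each API with a provider; each provider's
# sample_run_config() output is used as its provider config.
run_config = run_config_from_adhoc_config_spec("inference=fireworks,safety=llama-guard")
print(run_config.apis)  # -> ["inference", "safety"]
```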


@ -98,15 +98,23 @@ case "$env_type" in
*)
esac
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
set -x
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"


@ -1,72 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
exit 1
fi
venv_path="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
exit 1
fi
source "$venv_path/bin/activate"
set -x
python -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args


@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v7"
KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@ -1,199 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def config():
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
if os.path.exists(config.db_path):
os.remove(config.db_path)
return config
@pytest_asyncio.fixture(scope="function")
async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest_asyncio.fixture(scope="function")
async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest.fixture
def sample_vector_db():
return VectorDB(
identifier="test_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_id="test-provider",
)
@pytest.fixture
def sample_model():
return Model(
identifier="test_model",
provider_resource_id="test_model",
provider_id="test-provider",
)
@pytest.mark.asyncio
async def test_registry_initialization(registry):
# Test empty registry
result = await registry.get("nonexistent", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_basic_registration(registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await registry.register(sample_vector_db)
print(f"Registering {sample_model}")
await registry.register(sample_model)
print("Getting vector_db")
result_vector_db = await registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id
result_model = await registry.get("model", "test_model")
assert result_model is not None
assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
await disk_registry.initialize()
await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_model)
# Test cached version loads from disk
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
assert result_vector_db.provider_id == sample_vector_db.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
new_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(new_vector_db)
# Verify in cache
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
original_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(original_vector_db)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="different-model",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_vector_db)
result = await cached_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
# Create multiple test vector DBs
test_vector_dbs = [
VectorDB(
identifier=f"test_vector_db_{i}",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id=f"test_vector_db_{i}",
provider_id=f"provider_{i}",
)
for i in range(3)
]
# Register all vector_dbs
for vector_db in test_vector_dbs:
await cached_registry.register(vector_db)
# Test get_all retrieval
all_results = await cached_registry.get_all()
assert len(all_results) == 3
# Verify each vector_db was stored correctly
for original_vector_db in test_vector_dbs:
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
assert len(matching_vector_dbs) == 1
stored_vector_db = matching_vector_dbs[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension


@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.9-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]


@ -17,7 +17,7 @@ llama stack run together
2. (Optional) Register datasets and eval tasks as resources if you want to run pre-configured evaluation flows (e.g. the Evaluations (Generation + Scoring) page).
```bash
$ llama-stack-client datasets register \
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```
```bash
$ llama-stack-client benchmarks register \
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
@ -40,3 +40,13 @@ cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |


@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def datasets():


@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def benchmarks():


@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def models():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def providers():

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from page.distribution.benchmarks import benchmarks
from page.distribution.datasets import datasets
from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from page.distribution.vector_dbs import vector_dbs
from streamlit_option_menu import option_menu
from llama_stack.distribution.ui.page.distribution.datasets import datasets
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.distribution.ui.page.distribution.models import models
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.distribution.ui.page.distribution.shields import shields
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def scoring_functions():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def shields():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def vector_dbs():

View file

@ -8,8 +8,9 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import process_dataset
def application_evaluation_page():

View file

@ -8,7 +8,8 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def select_benchmark_1():
@ -166,11 +167,10 @@ def run_evaluation_3():
eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated(
rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
total_rows = len(rows.rows)
total_rows = len(rows.data)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
@ -195,7 +195,7 @@ def run_evaluation_3():
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.rows
rows = rows.data
if num_rows < total_rows:
rows = rows[:num_rows]
@ -212,7 +212,7 @@ def run_evaluation_3():
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
task_config=benchmark_config,
benchmark_config=benchmark_config,
)
for k in r.keys():
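For context, a hedged sketch of the new pagination call introduced in this diff: `datasets.iterrows()` replaces `get_rows_paginated()`, and the fetched rows live under `.data`. The client setup and the dataset id are assumptions for illustration.

```python
# Assumes llama_stack_api.client is an initialized llama-stack client.
rows = llama_stack_api.client.datasets.iterrows(dataset_id="mmlu")

total_rows = len(rows.data)
sample = rows.data[:5]  # evaluate only the first few examples
print(f"fetched {total_rows} rows; first row keys: {list(sample[0].keys())}")
```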

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:

View file

@ -7,10 +7,10 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
from llama_stack_client.types.shared.document import Document
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page():
@ -124,26 +124,22 @@ def rag_chat_page():
else:
strategy = {"type": "greedy"}
agent_config = AgentConfig(
agent = Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
toolgroups=[
tools=[
dict(
name="builtin::rag",
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
"vector_db_ids": list(selected_vector_dbs),
},
)
],
tool_choice="auto",
tool_prompt_format="json",
enable_session_persistence=False,
)
agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session")
# Chat input

View file

@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
except StopAsyncIteration:
break
return wrapper()
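A minimal usage sketch for the wrapper above, importing it from its module path (seen in the accompanying tests); the generator and context variable are illustrative only:

```python
import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

request_id: ContextVar[str] = ContextVar("request_id", default="unset")

async def stream_chunks():
    # The wrapped generator sees whatever value the caller set before iterating.
    for i in range(3):
        yield f"chunk-{i} (request={request_id.get()})"

async def main():
    request_id.set("req-42")
    wrapped = preserve_contexts_async_generator(stream_chunks(), [request_id])
    async for chunk in wrapped:
        print(chunk)

asyncio.run(main())
```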

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
import logging
import os
import select
import signal
import subprocess
import sys
from termcolor import cprint
@ -20,14 +17,14 @@ import importlib
import json
from pathlib import Path
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == ImageType.container.value or config.container_image:
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
elif image_type == ImageType.conda.value:
elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
if not env_name:
@ -46,7 +43,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
if os.path.basename(envpath) == env_name:
return envpath
return None
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
return _run_with_pty_win(command)
else:
return _run_with_pty_unix(command)
def in_notebook():
try:
from IPython import get_ipython
@ -108,19 +98,19 @@ def in_notebook():
return True
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def _run_with_pty_unix(command):
import pty
import termios
def run_command(command: list[str]) -> int:
"""
Run a command with interrupt handling and output capture.
Uses subprocess.run with direct stream piping for better performance.
master, slave = pty.openpty()
Args:
command (list): The command to run.
old_settings = termios.tcgetattr(sys.stdin)
Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin)
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
# Run the command with stdout/stderr piped directly to system streams
result = subprocess.run(
command,
stdin=slave,
stdout=slave,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
text=True,
check=False,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
# run a command in a pseudo-terminal in windows, with interrupt handling,
def _run_with_pty_win(command):
"""
Runs a command with interactive support using subprocess directly.
"""
try:
# For shell scripts on Windows, use appropriate shell
if isinstance(command, (list, tuple)):
if command[0].endswith(".sh"):
if os.path.exists("/usr/bin/bash"): # WSL
command = ["bash"] + command
else:
# Use cmd.exe with bash while preserving all arguments
command = ["cmd.exe", "/c", "bash"] + command
process = subprocess.Popen(
command,
shell=True,
universal_newlines=True,
)
process.wait()
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")
return 1
except Exception as e:
print(f"Error: {str(e)}")
log.exception(f"Unexpected error: {e}")
return 1
finally:
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
print("Script Output\n", result.stdout)
return result.returncode
except subprocess.CalledProcessError as e:
print("Error running script:", e)
print("Error output:", e.stderr)
return e.returncode
# Restore the original signal handler
signal.signal(signal.SIGINT, original_sigint)

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
import enum
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"
class LlamaStackImageType(enum.Enum):
CONTAINER = "container"
CONDA = "conda"
VENV = "venv"

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
token = context_var.set("start_value")
# Create an async generator that raises an exception
async def exception_generator():
yield context_var.get()
context_var.set("modified")
raise ValueError("Test exception")
yield None # This will never be reached
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
# First iteration should work
value = await wrapped_gen.__anext__()
assert value == "start_value"
# Second iteration should raise the exception
with pytest.raises(ValueError, match="Test exception"):
await wrapped_gen.__anext__()
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
token = context_var.set("value")
# Create an empty async generator
async def empty_generator():
if False: # This condition ensures the generator yields nothing
yield None
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
# The generator should raise StopAsyncIteration immediately
with pytest.raises(StopAsyncIteration):
await wrapped_gen.__anext__()
# Context variable should remain unchanged
assert context_var.get() == "value"
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.
This simulates the real-world scenario where:
1. A new event loop is created for each streaming request
2. The async generator runs inside that loop
3. There are multiple levels of nested generators
4. Context needs to be preserved across these boundaries
"""
# Create context variables
request_id = ContextVar("request_id", default=None)
user_id = ContextVar("user_id", default=None)
# Set initial values
# Results container to verify values across thread boundaries
results = []
# Inner-most generator (level 2)
async def inner_generator():
# Should have the context from the outer scope
yield (1, request_id.get(), user_id.get())
# Modify one context variable
user_id.set("user-modified")
# Should reflect the modification
yield (2, request_id.get(), user_id.get())
# Middle generator (level 1)
async def middle_generator():
inner_gen = inner_generator()
# Forward the first yield from inner
item = await inner_gen.__anext__()
yield item
# Forward the second yield from inner
item = await inner_gen.__anext__()
yield item
request_id.set("req-modified")
# Add our own yield with both modified variables
yield (3, request_id.get(), user_id.get())
# Function to run in a separate thread with a new event loop
def run_in_new_loop():
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Outer generator (runs in the new loop)
async def outer_generator():
request_id.set("req-12345")
user_id.set("user-6789")
# Wrap the middle generator
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
# Process all items from the middle generator
async for item in wrapped_gen:
# Store results for verification
results.append(item)
# Run the outer generator in the new loop
loop.run_until_complete(outer_generator())
finally:
loop.close()
# Run the generator chain in a separate thread with a new event loop
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
future.result() # Wait for completion
# Verify the results
assert len(results) == 3
# First yield should have original values
assert results[0] == (1, "req-12345", "user-6789")
# Second yield should have modified user_id
assert results[1] == (2, "req-12345", "user-modified")
# Third yield should have both modified values
assert results[2] == (3, "req-modified", "user-modified")

llama_stack/log.py (new file, 243 lines)
View file

@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from logging.config import dictConfig
from typing import Dict, Optional
from rich.console import Console
from rich.errors import MarkupError
from rich.logging import RichHandler
from termcolor import cprint
from .distribution.datatypes import LoggingConfig
# Default log level
DEFAULT_LOG_LEVEL = logging.INFO
# Predefined categories
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
def config_to_category_levels(category: str, level: str):
"""
Helper function to be called either by environment parsing or yaml parsing to go from a list of categories and levels to a dictionary ready to be
used by the logger dictConfig.
Parameters:
category (str): logging category to apply the level to
level (str): logging level to be used in the category
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels: Dict[str, int] = {}
level_value = logging._nameToLevel.get(str(level).upper())
if level_value is None:
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
return category_levels
if category == "all":
# Apply the log level to all categories and the root logger
for cat in CATEGORIES:
category_levels[cat] = level_value
# Set the root logger's level to the specified level
category_levels["root"] = level_value
elif category in CATEGORIES:
category_levels[category] = level_value
logging.info(f"Setting '{category}' category to level '{level}'.")
else:
logging.warning(f"Unknown logging category: {category}. No changes made.")
return category_levels
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
"""
Helper function to parse a yaml logging configuration found in the run.yaml
Parameters:
yaml_config (Logging): the logger config object found in the run.yaml
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for category, level in yaml_config.category_levels.items():
category_levels.update(config_to_category_levels(category=category, level=level))
return category_levels
def parse_environment_config(env_config: str) -> Dict[str, int]:
"""
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
Parameters:
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
category_levels.update(config_to_category_levels(category=category, level=level))
except ValueError:
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
return category_levels
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=120)
super().__init__(*args, **kwargs)
def emit(self, record):
"""Override emit to handle markup errors gracefully."""
try:
super().emit(record)
except MarkupError:
original_markup = self.markup
self.markup = False
try:
super().emit(record)
finally:
self.markup = original_markup
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
"""
Configure logging based on the provided category log levels and an optional log file.
Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
log_file (str): Path to a log file to additionally pipe the logs into
"""
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
class CategoryFilter(logging.Filter):
"""Ensure category is always present in log records."""
def filter(self, record):
if not hasattr(record, "category"):
record.category = "uncategorized" # Default to 'uncategorized' if no category found
return True
# Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING)
handlers = {
"console": {
"()": CustomRichHandler, # Use custom console handler
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
}
# Add a file handler if log_file is set
if log_file:
handlers["file"] = {
"class": "logging.FileHandler",
"formatter": "rich",
"filename": log_file,
"mode": "a",
"encoding": "utf-8",
}
logging_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {
"()": logging.Formatter,
"format": log_format,
}
},
"handlers": handlers,
"filters": {
"category_filter": {
"()": CategoryFilter,
}
},
"loggers": {
category: {
"handlers": list(handlers.keys()), # Apply all handlers
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger
}
for category in CATEGORIES
},
"root": {
"handlers": list(handlers.keys()),
"level": root_level, # Set root logger's level dynamically
},
}
dictConfig(logging_config)
# Ensure third-party libraries follow the root log level
for _, logger in logging.root.manager.loggerDict.items():
if isinstance(logger, logging.Logger):
logger.setLevel(root_level)
def get_logger(
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] = None
) -> logging.LoggerAdapter:
"""
Returns a logger with the specified name and category.
If no category is provided, defaults to 'uncategorized'.
Parameters:
name (str): The name of the logger (e.g., module or filename).
category (str): The category of the logger (default 'uncategorized').
config (Logging): optional yaml config to override the existing logger configuration
Returns:
logging.LoggerAdapter: Configured logger with category support.
"""
if config:
_category_levels.update(parse_yaml_config(config))
logger = logging.getLogger(name)
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
return logging.LoggerAdapter(logger, {"category": category})
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
_category_levels.update(parse_environment_config(env_config))
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
setup_logging(_category_levels, log_file)
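A usage sketch for the logger above. Since `LLAMA_STACK_LOGGING` is read at import time, it must be set before the module is imported; the category levels shown are invented for illustration.

```python
import os

# Semicolon-separated "category=level" pairs, as parsed by parse_environment_config.
os.environ["LLAMA_STACK_LOGGING"] = "inference=DEBUG;server=WARNING"

from llama_stack.log import get_logger  # import after setting the variable

logger = get_logger(name=__name__, category="inference")
logger.debug("tokenizing prompt")  # emitted: the inference category is at DEBUG
logger.warning("falling back to default model")
```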

View file

@ -11,16 +11,135 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union
# import all for backwards compatibility
from llama_models.datatypes import * # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
class Role(Enum):
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
class BuiltinTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
Primitive = Union[str, int, float, bool, None]
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolPromptFormat(Enum):
"""Prompt format for calling custom / zero shot tools.
:cvar json: JSON format for calling tools. It takes the form:
{
"type": "function",
"function" : {
"name": "function_name",
"description": "function_description",
"parameters": {...}
}
}
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
<function=function_name>(parameters)</function>
:cvar python_list: Python list. The output is a valid Python expression that can be
evaluated to a list. Each element in the list is a function call. Example:
["function_name(param1, param2)", "function_name(param1, param2)"]
"""
json = "json"
function_tag = "function_tag"
python_list = "python_list"
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_serializer("data")
def serialize_data(self, data: Optional[bytes], _info):
if data is None:
return None
return base64.b64encode(data).decode("utf-8")
@field_validator("data", mode="before")
@classmethod
def validate_data(cls, v):
if isinstance(v, str):
return base64.b64decode(v)
return v
class RawTextItem(BaseModel):
type: Literal["text"] = "text"
text: str
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
RawContent = str | RawContentItem | List[RawContentItem]
class RawMessage(BaseModel):
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
content: RawContent
# This is for RAG but likely should be absorbed into content
context: Optional[RawContent] = None
# These are for the output message coming from the assistant
stop_reason: Optional[StopReason] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall)
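For illustration, a small hedged sketch of constructing ToolCall instances with the fields described above; the tool names and arguments are invented:

```python
import json

args = {"city": "Paris", "metric": "celsius"}
call = ToolCall(
    call_id="call-1",
    tool_name="get_weather",            # unknown names stay plain strings
    arguments=args,
    arguments_json=json.dumps(args),    # pre-serialized form, per the comment above
)
assert call.tool_name == "get_weather"

builtin = ToolCall(
    call_id="call-2",
    tool_name="brave_search",           # the validator promotes known names to BuiltinTool
    arguments={"query": "llama stack"},
)
assert builtin.tool_name is BuiltinTool.brave_search
```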
@ -67,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema(
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type

View file

@ -0,0 +1,285 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils
@dataclass
class VisionInput:
mask: List[List[int]]
images: List[PIL_Image.Image]
@dataclass
class LLMInput:
tokens: List[int]
vision: Optional[VisionInput] = None
def role_str(role: Role) -> str:
role_strs = {
Role.user: "user",
Role.system: "system",
Role.tool: "ipython", # special
Role.assistant: "assistant",
}
return role_strs[role]
class ChatFormat:
possible_headers: Dict[Role, str]
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_content(self, content: RawContent) -> LLMInput:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = []
images = []
added_bos = False
def _process(c):
nonlocal added_bos, bos
if isinstance(c, str) or isinstance(c, RawTextItem):
if isinstance(c, RawTextItem):
c = c.text
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
added_bos = True
elif isinstance(c, RawMediaItem):
bos = False if added_bos else bos
if bos:
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
added_bos = True
tokens.append(self.vision_token)
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = image.convert("RGB")
images.append(image)
if isinstance(content, list):
for c in content:
_process(c)
else:
_process(content)
return tokens, images
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []
def _process_content(c):
toks, imgs = self._encode_content(c)
tokens.extend(toks)
images.extend(imgs)
if (
message.role == "assistant"
and len(message.tool_calls) > 0
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
):
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
_process_content(message.content)
if message.role == "user" and message.context is not None:
# This is RAG context; why is it here in the chat format? I don't think
# this is needed and can be moved upwards
_process_content("\n\n")
_process_content(message.context)
if message.role == "assistant":
for t in message.tool_calls:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
return tokens, images
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []
images = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in messages:
toks, imgs = self.encode_message(message, tool_prompt_format)
tokens.extend(toks)
images.extend(imgs)
# Add the start of an assistant message for the model to complete.
tokens.extend(self._encode_header("assistant"))
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
content = content.strip(" ")
header_str = self.possible_headers[Role.assistant]
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_tag|>")
if ipython:
content = content[len("<|python_tag|>") :]
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes, when the agent has custom tools alongside builtin tools,
# it responds to builtin tool calls in the format of the custom tools.
# This code tries to handle that case.
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(
mask=create_vision_mask(tokens, self.vision_token),
images=images,
)
return LLMInput(
tokens=[128256 if token == self.vision_token else token for token in tokens],
vision=vision_input,
)
def create_vision_mask(
tokens: List[int],
vision_token: int,
) -> List[List[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []
if len(vision_token_locations) == 1:
# only one image present, unmask until end of sequence
return [[vision_token_locations[0], -1]]
vision_masks = [
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
]
# last image will attend to all subsequent text
vision_masks.append([vision_token_locations[-1], len(tokens)])
# if there are two or more consecutive vision tokens,
# they should all attend to all subsequent
# text present
last_mask_end = vision_masks[-1][1]
for vision_mask in vision_masks[::-1]:
if vision_mask[0] == vision_mask[1] - 1:
vision_mask[1] = last_mask_end
last_mask_end = vision_mask[1]
return vision_masks
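A hedged usage sketch for the formatter above, assuming the packaged tokenizer.model is available so `Tokenizer.get_instance()` can load it; the dialog contents are invented:

```python
dialog = [
    RawMessage(role="system", content="You are a helpful assistant."),
    RawMessage(role="user", content="Name three primary colors."),
]

formatter = ChatFormat(Tokenizer.get_instance())
llm_input = formatter.encode_dialog_prompt(dialog)

# encode_dialog_prompt appends the assistant header, so the model is primed to respond.
print(len(llm_input.tokens), "tokens; vision input:", llm_input.vision)
```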

View file

@ -14,20 +14,19 @@
from pathlib import Path
from typing import List, Optional
from llama_models.datatypes import (
from termcolor import colored
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored
from llama_stack.models.llama.datatypes import ToolDefinition
from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
@ -35,6 +34,7 @@ from .prompt_templates import (
SystemDefaultGenerator,
ToolResponseGenerator,
)
from .tokenizer import Tokenizer
THIS_DIR = Path(__file__).parent

View file

@ -15,11 +15,8 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_models.datatypes import (
BuiltinTool,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
)
@ -37,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
)
return PromptTemplate(
template_str.lstrip("\n"),
{"today": datetime.now().strftime("%d %B %Y")},
{
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
},
)
def data_examples(self) -> List[Any]:
@ -226,10 +225,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
DEFAULT_PROMPT = textwrap.dedent(
"""
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
{{ function_description }}
""".strip("\n")
@ -246,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.

View file

@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from llama_models.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,

View file

@ -1,199 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
import unittest
from datetime import datetime
from .prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
class PromptTemplateTests(unittest.TestCase):
def check_generator_output(self, generator, expected_text):
example = generator.data_examples()[0]
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
def test_system_default(self):
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
self.check_generator_output(generator, expected_text)
def test_system_builtin_only(self):
generator = BuiltinToolGenerator()
expected_text = textwrap.dedent(
"""
Environment: ipython
Tools: brave_search, wolfram_alpha
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_system_custom_only(self):
self.maxDiff = None
generator = JsonCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
"type": "function",
"function": {
"name": "trending_songs",
"description": "Returns the trending songs on a Music site",
"parameters": {
"type": "object",
"properties": [
{
"n": {
"type": "object",
"description": "The number of songs to return"
}
},
{
"genre": {
"type": "object",
"description": "The genre of the songs to return"
}
}
],
"required": ["n"]
}
}
}
Return function calls in JSON format.
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_system_custom_function_tag(self):
self.maxDiff = None
generator = FunctionTagCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_llama_3_2_system_zero_shot(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_llama_3_2_provided_system_prompt(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Overriding message.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]"""
)
user_system_prompt = textwrap.dedent(
"""
Overriding message.
{{ function_description }}
"""
)
example = generator.data_examples()[0]
pt = generator.gen(example, user_system_prompt)
text = pt.render()
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"

File diff suppressed because it is too large

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
class Tokenizer:
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
@classmethod
def get_instance(cls):
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
return _INSTANCE
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
mergeable_ranks = load_tiktoken_bpe(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|step_id|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
"<|image|>",
]
reserved_tokens = [
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
]
special_tokens = special_tokens + reserved_tokens
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
self.n_words: int = num_base_tokens + len(special_tokens)
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
self.eot_id: int = self.special_tokens["<|eot_id|>"]
self.eom_id: int = self.special_tokens["<|eom_id|>"]
self.python_tag_id = self.special_tokens["<|python_tag|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom_id|>"],
self.special_tokens["<|eot_id|>"],
]
def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
allowed_special ("all"|set[str]): allowed special tokens in string
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
Returns:
list[int]: A list of token IDs.
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
- Setting `disallowed_special` to () will cause all text corresponding
to special tokens to be encoded as natural text (instead of raising
an error).
- Setting `allowed_special` to "all" will treat all text corresponding
to special tokens as special tokens.
"""
if allowed_special is None:
allowed_special = set()
assert type(s) is str
substrs = (
substr
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
for substr in substrs:
t.extend(
self.model.encode(
substr,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def decode(self, t: Sequence[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]
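A brief hedged sketch of round-tripping text through the tokenizer above; it assumes tokenizer.model sits next to the module, which is what `get_instance()` expects:

```python
tok = Tokenizer.get_instance()

ids = tok.encode("Hello, Llama!", bos=True, eos=False)
assert ids[0] == tok.bos_id          # BOS was prepended

text = tok.decode(ids[1:])           # drop BOS before decoding back to text
assert text == "Hello, Llama!"
```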

View file

@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
def is_json(s):
try:
parsed = json.loads(s)
# Return True for valid objects and not for ints, strings, etc
return isinstance(parsed, dict)
except json.JSONDecodeError:
return False
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# Extract keyword arguments
for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
result.append((function_name, function_args))
return result
class ToolUtils:
@staticmethod
def is_builtin_tool_call(message_body: str) -> bool:
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
return match is not None
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
# Check if a match is found and return it
if match:
tool_name = match.group("tool_name")
query = match.group("query")
return tool_name, query
else:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# NOTE: Custom function tool calls are still experimental
# Sometimes the response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}}
# and sometimes
# <function=function_name>(parameters)</function>
# Find the first match in the text
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
if match:
tool_name = match.group("function_name")
query = match.group("args")
try:
return tool_name, json.loads(query.replace("'", '"'))
except Exception as e:
print("Exception while parsing json query for custom tool call", query, e)
return None
elif is_json(message_body):
response = json.loads(message_body)
if ("type" in response and response["type"] == "function") or ("name" in response):
function_name = response["name"]
args = response["parameters"]
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
# FIXME: Enable multiple tool calls
return res[0]
else:
return None
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
if t.tool_name == BuiltinTool.brave_search:
q = t.arguments["query"]
return f'brave_search.call(query="{q}")'
elif t.tool_name == BuiltinTool.wolfram_alpha:
q = t.arguments["query"]
return f'wolfram_alpha.call(query="{q}")'
elif t.tool_name == BuiltinTool.photogen:
q = t.arguments["query"]
return f'photogen.call(query="{q}")'
elif t.tool_name == BuiltinTool.code_interpreter:
return t.arguments["code"]
else:
fname = t.tool_name
if tool_prompt_format == ToolPromptFormat.json:
return json.dumps(
{
"type": "function",
"name": fname,
"parameters": t.arguments,
}
)
elif tool_prompt_format == ToolPromptFormat.function_tag:
args = json.dumps(t.arguments)
return f"<function={fname}>{args}</function>"
elif tool_prompt_format == ToolPromptFormat.python_list:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"
elif isinstance(value, dict):
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
else:
raise ValueError(f"Unsupported type: {type(value)}")
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
return f"[{fname}({args_str})]"
else:
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")

View file

@ -0,0 +1,358 @@
# Llama 3.1 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 3.1:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
There are 4 different roles that are supported by Llama 3.1
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
## Llama 3.1 Base Model
Text completion for Llama 3.1 base model uses this format.
##### Input Prompt Format
```
<|begin_of_text|>Color of sky is blue but sometimes can also be
```
##### Model Response Format
```
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
The color of the sky is primarily blue because of a phenomenon called
```
Note the special `<|begin_of_text|>` tag at the start of the prompt.
## Llama 3.1 Instruct Model
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how it's formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
Here's my response
"What is a helpful assistant?"<|eot_id|>
```
## Tool Calling Formats
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.
## Builtin Tool Calling
Here is an example of a conversation using brave search
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|>. The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`; a minimal parsing sketch follows below
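The pattern below is illustrative, not the library's exact regex, but shows how an executor might recover the tool name and query from such a response:
```
import re

# Illustrative pattern for tool.call(query="...") style responses
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'

response = '<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'
match = re.search(BUILTIN_TOOL_PATTERN, response)
if match:
    print(match.group("tool_name"))  # brave_search
    print(match.group("query"))      # latest price of 1oz gold
```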
## Builtin Code Interpreter
Here is an actual example of model responding with code
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>def is_prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
print(is_prime(7)) # Output: True<|eom_id|>
```
- Model starts with <|python_tag|> and continues writing python code that needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
## Built-in tools full interaction
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
{
    "queryresult": {
        "success": true,
        "inputstring": "100th decimal of pi",
        "pods": [
            {
                "title": "Input interpretation",
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "100th digit | π"
                    }
                ]
            },
            {
                "title": "Nearby digits",
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "...86208998628034825342117067982148086513282306647093..."
                    }
                ]
            },
            {
                "title": "Result",
                "primary": true,
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "7"
                    }
                ]
            }
        ]
    }
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The 100th decimal of pi is 7.<|eot_id|>
```
- Note the `<|python_tag|>` in the assistant response.
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
## Zero shot tool calling
## JSON based tool calling
Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
    "type": "function",
    "function": {
        "name": "trending_songs",
        "description": "Returns the trending songs on a Music site",
        "parameters": {
            "type": "object",
            "properties": [
                {
                    "n": {
                        "type": "object",
                        "description": "The number of songs to return"
                    }
                },
                {
                    "genre": {
                        "type": "object",
                        "description": "The genre of the songs to return"
                    }
                }
            ],
            "required": ["n"]
        }
    }
}
Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>{
    "type": "function",
    "name": "trending_songs",
    "parameters": {
        "n": "10",
        "genre": "all"
    }
}<|eom_id|>
```
- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now; a minimal parsing sketch follows below
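A minimal sketch of how an executor could parse this JSON-style response; the tag handling here is simplified and illustrative:
```
import json

raw = '<|python_tag|>{"type": "function", "name": "trending_songs", "parameters": {"n": "10", "genre": "all"}}<|eom_id|>'

# Strip the special tags before parsing the JSON body
body = raw.removeprefix("<|python_tag|>").removesuffix("<|eom_id|>")
call = json.loads(body)
if call.get("type") == "function" or "name" in call:
    print(call["name"], call.get("parameters", {}))
    # trending_songs {'n': '10', 'genre': 'all'}
```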
## Example of a user defined tool calling
## `<function>` based tool calling
Here is an example of how you could also write custom instructions for the model to do zero-shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<function=trending_songs>{"n": 10}</function><|eot_id|>
```
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools are added as a user message; a parsing sketch follows below
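A minimal sketch of parsing the `<function=...>` response format; the regex is illustrative, not the exact one used by the library:
```
import json
import re

FUNCTION_TAG_PATTERN = re.compile(r"<function=(?P<name>[^>]+)>(?P<args>\{.*?\})</function>")

response = '<function=trending_songs>{"n": 10}</function><|eot_id|>'
match = FUNCTION_TAG_PATTERN.search(response)
if match:
    print(match.group("name"), json.loads(match.group("args")))
    # trending_songs {'n': 10}
```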
Thank You!

View file

@ -14,7 +14,7 @@
import textwrap
from typing import List
from llama_models.datatypes import (
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawMessage,
StopReason,

View file

@ -13,7 +13,7 @@
import json
import textwrap
from llama_models.datatypes import (
from llama_stack.models.llama.datatypes import (
RawMessage,
StopReason,
ToolCall,

View file

@ -14,7 +14,7 @@
import textwrap
from pathlib import Path
from llama_models.datatypes import (
from llama_stack.models.llama.datatypes import (
RawMediaItem,
RawMessage,
RawTextItem,

View file

@ -0,0 +1,286 @@
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how its formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.
I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.
I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.
So, what can I help you with today?<|eot_id|>
```
##### Notes
This format is unchanged from Llama3.1
## Zero shot function calling
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
This new format is designed to be more flexible and powerful than the previous format.
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
It is pythonic, in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]`, instead of the `json` or `<function>` tag formats defined in Llama3.1.
Here is an example for the same,
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
```
##### Notes
- The output supports multiple tool calls natively
- The JSON format for defining the functions in the system prompt is similar to Llama3.1; a sketch for parsing the pythonic response follows below
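A minimal sketch (not the stack's actual parser) of turning the pythonic response into structured calls using Python's `ast` module:
```
import ast

response = "[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]"

# The response is a Python list literal of call expressions
tree = ast.parse(response, mode="eval")
calls = []
for node in tree.body.elts:  # the top-level expression is a list of Call nodes
    name = node.func.id
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    calls.append((name, kwargs))

print(calls)
# [('get_weather', {'city': 'San Francisco', 'metric': 'celsius'}),
#  ('get_weather', {'city': 'Seattle', 'metric': 'celsius'})]
```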
## Zero shot function calling with user message
While the default is to provide all function definitions in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                    "type": "integer",
                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
                },
                "special": {
                    "type": "string",
                    "description": "Any special information or parameters that need to be considered while fetching user details.",
                    "default": "none"
                }
            }
        }
    }
]
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
[get_user_info(user_id=7890, special='black')]<|eot_id|>
```
##### Notes
- The tool call format for the model is the same whether your function definitions are provided in the system or user message.
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
## Code Interpreter
Code Interpreter continues to work in 3.2 text models similarly to the Llama 3.1 model family.
Here is an example,
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 24 September 2024
<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>def is_prime(n):
    if n <= 1:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    max_divisor = int(n**0.5) + 1
    for d in range(3, max_divisor, 2):
        if n % d == 0:
            return False
    return True
print(is_prime(7)) # Output: True<|eom_id|>
```
##### Notes
- Note `Environment: ipython` in the system prompt.
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
## Zero shot function calling E2E format
Here is an example of the e2e cycle of tool calls with the model in a multi-step way.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The weather in San Francisco is 25 C.<|eot_id|>
```
##### Notes
- The output of the function call is provided back to the model as a tool response (in JSON format).
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
- The model finally summarizes the information from the tool response and returns the result to the user. A sketch of this end-to-end loop follows below.
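A rough sketch of the multi-step loop an executor might run around this format; `generate`, `run_tool`, and `is_tool_call` are hypothetical stand-ins for your inference call, tool executor, and tool-call detection (for example, the parsing sketches shown earlier):
```
def run_turn(generate, run_tool, is_tool_call, prompt: str) -> str:
    while True:
        output = generate(prompt)  # raw model text, ending in <|eom_id|> or <|eot_id|>
        body = (
            output.removeprefix("<|python_tag|>")
            .removesuffix("<|eom_id|>")
            .removesuffix("<|eot_id|>")
        )
        if not is_tool_call(body):
            return body  # final answer for the user
        result = run_tool(body)  # e.g. '"25 C"'
        # Feed the tool output back under the ipython role and continue the turn
        prompt += (
            output
            + "<|start_header_id|>ipython<|end_header_id|>\n\n"
            + result
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
```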
## Prompt format for base models
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```
##### Model Response Format
```
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.
## Step 1: Identify the factors that
```
##### Notes
Same as Llama3.1
Thank You!

View file

@ -0,0 +1,141 @@
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how its formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.
I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.
What can I help you with today?<|eot_id|>
```
##### Notes
This format is unchanged from Llama3.1
## User and assistant conversation with Images
This example shows how to pass an image to the model as part of the messages.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
```
##### Notes
- The `<|image|>` tag is used to indicate presence of the image
- The model isn't an early fusion model, so it doesn't actually translate an image into several tokens. Instead, the cross-attention layers take input "on the side" from a vision encoder
![Image](mm-model.png)
- It's important to position the <|image|> tag appropriately in the prompt. The image will only attend to the subsequent text tokens
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
- We recommend using a single image in one prompt; a prompt-assembly sketch follows below
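A minimal sketch of assembling a single-image prompt with the `<|image|>` tag placed at the start of the user message body; the helper name is illustrative, and the image bytes themselves are supplied to the vision encoder separately through whatever image-input mechanism your inference stack provides:
```
def build_vision_prompt(user_text: str) -> str:
    # <|image|> goes inside the user message body, before the text,
    # so the following text tokens can interact with the image.
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<|image|>{user_text}"
        "<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(build_vision_prompt("Describe this image in two sentences"))
```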
## Builtin and Zero Shot Tool Calling
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
Use `Environment: ipython` to enable tools.
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
The same builtin tools as Llama3.1 are available,
- code_interpreter (for executing python code)
- brave_search (to search the web)
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 23 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```
##### Notes
- Note the `<|python_tag|>` before `brave_search` function call.
- The `<|eom_id|>` tag is used to indicate the end of the message.
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
- Tool Calling does NOT work with images in the prompt as of now.
## Prompt format for base models
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```
##### Model Response Format
```
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
```
##### Notes
- Same as Llama3.1
## Prompt format for base models with Image
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
##### Input Prompt Format
```
<|begin_of_text|><|image|>If I had to write a haiku for this one
```
##### Model Response Format
```
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
```
##### Notes
- Note the placement of the special tags <|begin_of_text|> and <|image|>
Thank You!

View file

@ -14,7 +14,7 @@
import textwrap
from typing import List
from llama_models.datatypes import (
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawMessage,
StopReason,

View file

@ -16,7 +16,9 @@ import textwrap
from pathlib import Path
from typing import List
from llama_models.datatypes import (
from pydantic import BaseModel, Field
from llama_stack.models.llama.datatypes import (
RawContent,
RawMediaItem,
RawMessage,
@ -25,7 +27,6 @@ from llama_models.datatypes import (
ToolCall,
ToolPromptFormat,
)
from pydantic import BaseModel, Field
from .llama3.interface import LLama31Interface
from .llama3.template_data import (

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,12 +22,15 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
@ -83,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -119,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -140,7 +144,6 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -150,7 +153,6 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -161,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -170,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
@ -189,22 +191,18 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
steps = turn.steps
for step in steps:
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
for step in turn.steps:
if step.step_id == step_id:
return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found")
@ -215,20 +213,18 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
session = Session(**json.loads(session), turns=[])
turns = []
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
for turn_id in turn_ids:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)
turns = [turn for turn in turns if turn.turn_id in turn_ids]
return Session(
session_name=session.session_name,
session_name=session_info.session_name,
session_id=session_id,
turns=turns if turns else [],
started_at=session.started_at,
turns=turns,
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
@ -239,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass

View file

@ -7,7 +7,7 @@
import json
import logging
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel
@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
started_at: datetime
@ -35,7 +36,7 @@ class AgentPersistence:
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
started_at=datetime.now(timezone.utc),
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -85,6 +86,14 @@ class AgentPersistence:
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
@ -96,3 +105,15 @@ class AgentPersistence:
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None

Some files were not shown because too many files have changed in this diff.