Merge branch 'main' into eval_api_final

Commit 24d48b3692 by Xi Yan, 2025-03-18 20:17:24 -07:00
28 changed files with 329 additions and 110 deletions

.github/workflows/changelog.yml (new file, 29 additions)

@@ -0,0 +1,29 @@
name: Update Changelog

on:
  release:
    types: [published, unpublished, created, edited, deleted, released]

permissions:
  contents: read

jobs:
  generate_changelog:
    name: Generate changelog
    permissions:
      contents: write # for peter-evans/create-pull-request to create branch
      pull-requests: write # for peter-evans/create-pull-request to create a PR
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          ref: main
          fetch-depth: 0
      - run: |
          python ./scripts/gen-changelog.py
      - uses: peter-evans/create-pull-request@v7
        with:
          title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
          commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
          branch: create-pull-request/changelog
          signoff: true


@@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
| vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
### Installation
You have two ways to install this repository:
* **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
* **Install from source**:
If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
Then, run the following commands:
```bash
git clone git@github.com:meta-llama/llama-stack.git
cd llama-stack
uv sync
uv pip install -e .
```
### Documentation


@@ -10940,23 +10940,6 @@
],
"title": "ScoreBatchResponse"
},
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": {
"type": "object",
"properties": {
@@ -11092,7 +11075,14 @@
"type": "string"
},
"algorithm_config": {
"$ref": "#/components/schemas/AlgorithmConfig"
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
}
},
"additionalProperties": false,


@@ -7500,15 +7500,6 @@ components:
required:
- results
title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig:
type: object
properties:
@@ -7592,7 +7583,9 @@ components:
checkpoint_dir:
type: string
algorithm_config:
$ref: '#/components/schemas/AlgorithmConfig'
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false
required:
- job_uuid


@@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Any, Dict, List, Literal, Optional, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig",
)
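
For reference, a minimal self-contained sketch of how a pydantic discriminated union like the one registered above picks the concrete config from its `type` field; the models and field values below are illustrative stand-ins, not the full Llama Stack definitions:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str = "int4"


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]

# The "type" field selects the concrete model during validation.
cfg = TypeAdapter(AlgorithmConfig).validate_python({"type": "LoRA", "rank": 16})
assert isinstance(cfg, LoraFinetuningConfig)
```

The OpenAPI hunks above make the corresponding change on the spec side: the shared `AlgorithmConfig` schema is dropped and the `oneOf` is inlined at the `algorithm_config` use site.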
@@ -184,7 +184,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")


@@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
)
class AuthenticationConfig(BaseModel):
endpoint: str = Field(
...,
description="Endpoint URL to validate authentication tokens",
)
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
default=None,
description="Path to TLS key file for HTTPS",
)
auth: Optional[AuthenticationConfig] = Field(
default=None,
description="Authentication configuration for the server",
)
class StackRunConfig(BaseModel):


@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from urllib.parse import parse_qs

import httpx

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="auth")


class AuthenticationMiddleware:
    def __init__(self, app, auth_endpoint):
        self.app = app
        self.auth_endpoint = auth_endpoint

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            auth_header = headers.get(b"authorization", b"").decode()

            if not auth_header or not auth_header.startswith("Bearer "):
                return await self._send_auth_error(send, "Missing or invalid Authorization header")

            api_key = auth_header.split("Bearer ", 1)[1]

            path = scope.get("path", "")
            request_headers = {k.decode(): v.decode() for k, v in headers.items()}

            query_string = scope.get("query_string", b"").decode()
            params = parse_qs(query_string)

            auth_data = {
                "api_key": api_key,
                "request": {
                    "path": path,
                    "headers": request_headers,
                    "params": params,
                },
            }

            # Validate with authentication endpoint
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(self.auth_endpoint, json=auth_data)
                    if response.status_code != 200:
                        logger.warning(f"Authentication failed: {response.status_code}")
                        return await self._send_auth_error(send, "Authentication failed")
            except Exception:
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, "Authentication service error")

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message):
        await send(
            {
                "type": "http.response.start",
                "status": 401,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})
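
For context, a minimal sketch of a token-validation service this middleware could POST to; the route, module name, and accepted key are illustrative assumptions, not part of Llama Stack:

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()

# Hypothetical allow-list; a real service would check a database or identity provider.
VALID_KEYS = {"valid_api_key_12345"}


@app.post("/validate")
async def validate(payload: dict):
    # The middleware sends {"api_key": ..., "request": {"path", "headers", "params"}}.
    if payload.get("api_key") in VALID_KEYS:
        return {"ok": True}  # only a 200 status counts as success for the middleware
    raise HTTPException(status_code=401, detail="Invalid token")
```

Run it with something like `uvicorn validator:app --port 8080` and point the server's `auth.endpoint` at `http://localhost:8080/validate`; anything other than a 200 response causes the middleware to reject the request.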


@@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -351,6 +352,11 @@ def main():
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth and config.server.auth.endpoint:
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
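
Once auth is enabled, clients must attach the token as a Bearer header; a rough sketch, where the URL, route, and key are placeholders (the middleware only inspects the `Authorization` header):

```python
import httpx

resp = httpx.get(
    "http://localhost:8321/v1/models",  # placeholder route on the default port 8321
    headers={"Authorization": "Bearer valid_api_key_12345"},
)
print(resp.status_code, resp.text)
```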


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def datasets():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def benchmarks():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def models():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def providers():


@@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from page.distribution.benchmarks import benchmarks
from page.distribution.datasets import datasets
from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from page.distribution.vector_dbs import vector_dbs
from streamlit_option_menu import option_menu
from llama_stack.distribution.ui.page.distribution.datasets import datasets
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.distribution.ui.page.distribution.models import models
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.distribution.ui.page.distribution.shields import shields
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def scoring_functions():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def shields():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def vector_dbs():


@@ -8,8 +8,9 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import process_dataset
def application_evaluation_page():


@@ -8,7 +8,8 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def select_benchmark_1():


@@ -5,7 +5,8 @@
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:


@@ -7,9 +7,10 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
from llama_stack_client.types.shared.document import Document
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page():


@@ -10,6 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import logging
import multiprocessing
@@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return ProcessingMessageWrapper(**data).payload
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
def worker_process_entrypoint(


@@ -9,6 +9,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
@@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA = {
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")


@@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
):
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
@@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str:Any] = {}
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@@ -85,10 +85,10 @@
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
checkpoint_format: str | None = None,
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta":
if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving Hugging Face format checkpoints, we only support saving adapter weights now


@@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict
from typing import Callable, Dict
import torch
from pydantic import BaseModel
@@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
model_definition: BuildLoraModelCallable
tokenizer_type: BuildTokenizerCallable
checkpoint_type: str
@@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
}
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:


@@ -55,7 +55,7 @@ class SFTDataset(Dataset):
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample)
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())


@@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
@@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
# Currently logging only logs limited training metrics to local disk
# will figure out richer logging and how it works with telemetry in future PRs
_checkpointer: TorchtuneCheckpointer
def __init__(
self,
config: TorchtunePostTrainingConfig,
@@ -82,7 +85,7 @@
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
@@ -109,12 +112,12 @@
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir
self.checkpoint_dir = checkpoint_dir
else:
model = resolve_model(self.model_id)
if model is None:
model_obj = resolve_model(self.model_id)
if model_obj is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model)
self.checkpoint_dir = model_checkpoint_dir(model_obj)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format
@@ -135,16 +138,16 @@
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = False
self._enable_activation_offloading = False
if training_config.efficiency_config:
if training_config.efficiency_config.enable_activation_checkpointing:
self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_checkpointing)
if training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(training_config.efficiency_config.enable_activation_offloading)
if training_config.efficiency_config
else False
training_config.efficiency_config.enable_activation_checkpointing
)
if training_config.efficiency_config.enable_activation_offloading:
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
@@ -451,12 +454,12 @@
"""
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss = 0
running_loss: float = 0.0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
memory_stats: Dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -484,7 +487,7 @@
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss
running_loss += current_loss.detach().item()
current_loss.backward()
# Step with optimizer
@@ -500,7 +503,7 @@
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
loss_to_log = running_loss / num_tokens
pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
@@ -523,7 +526,7 @@
)
# Reset running stats for the next step
running_loss = 0
running_loss = 0.0
num_tokens = 0
t0 = time.perf_counter()


@@ -228,10 +228,6 @@ exclude = [
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",


@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack.distribution.server.auth import AuthenticationMiddleware


@pytest.fixture
def mock_auth_endpoint():
    return "http://mock-auth-service/validate"


@pytest.fixture
def valid_api_key():
    return "valid_api_key_12345"


@pytest.fixture
def invalid_api_key():
    return "invalid_api_key_67890"


@pytest.fixture
def app(mock_auth_endpoint):
    app = FastAPI()
    app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return app


@pytest.fixture
def client(app):
    return TestClient(app)


async def mock_post_success(*args, **kwargs):
    mock_response = AsyncMock()
    mock_response.status_code = 200
    return mock_response


async def mock_post_failure(*args, **kwargs):
    mock_response = AsyncMock()
    mock_response.status_code = 401
    return mock_response


async def mock_post_exception(*args, **kwargs):
    raise Exception("Connection error")


def test_missing_auth_header(client):
    response = client.get("/test")
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


def test_invalid_auth_header_format(client):
    response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_success)
def test_valid_authentication(client, valid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}


@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_authentication(client, invalid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
    assert response.status_code == 401
    assert "Authentication failed" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_auth_service_error(client, valid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 401
    assert "Authentication service error" in response.json()["error"]["message"]


def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
    with patch("httpx.AsyncClient.post") as mock_post:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        client.get(
            "/test?param1=value1&param2=value2",
            headers={
                "Authorization": f"Bearer {valid_api_key}",
                "User-Agent": "TestClient",
                "Content-Type": "application/json",
            },
        )

        # Check that the auth endpoint was called with the correct payload
        call_args = mock_post.call_args
        assert call_args is not None

        url, kwargs = call_args[0][0], call_args[1]
        assert url == mock_auth_endpoint

        payload = kwargs["json"]
        assert payload["api_key"] == valid_api_key
        assert payload["request"]["path"] == "/test"
        assert "authorization" in payload["request"]["headers"]
        assert "param1" in payload["request"]["params"]
        assert "param2" in payload["request"]["params"]