diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml
new file mode 100644
index 000000000..5b63e231c
--- /dev/null
+++ b/.github/workflows/changelog.yml
@@ -0,0 +1,29 @@
+name: Update Changelog
+
+on:
+  release:
+    types: [published, unpublished, created, edited, deleted, released]
+
+permissions:
+  contents: read
+
+jobs:
+  generate_changelog:
+    name: Generate changelog
+    permissions:
+      contents: write  # for peter-evans/create-pull-request to create branch
+      pull-requests: write  # for peter-evans/create-pull-request to create a PR
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: main
+          fetch-depth: 0
+      - run: |
+          python ./scripts/gen-changelog.py
+      - uses: peter-evans/create-pull-request@v7
+        with:
+          title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
+          commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
+          branch: create-pull-request/changelog
+          signoff: true
diff --git a/README.md b/README.md
index d2adc3376..918433d51 100644
--- a/README.md
+++ b/README.md
@@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
 | vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
 
-### Installation
-
-You have two ways to install this repository:
-
-* **Install as a package**:
-   You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
-   ```bash
-   pip install llama-stack
-   ```
-
-* **Install from source**:
-   If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
-   Then, run the following commands:
-   ```bash
-    git clone git@github.com:meta-llama/llama-stack.git
-    cd llama-stack
-
-    uv sync
-    uv pip install -e .
-   ```
 
 ### Documentation
diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ff53f1aed..a5ea562bf 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -10940,23 +10940,6 @@
                 ],
                 "title": "ScoreBatchResponse"
             },
-            "AlgorithmConfig": {
-                "oneOf": [
-                    {
-                        "$ref": "#/components/schemas/LoraFinetuningConfig"
-                    },
-                    {
-                        "$ref": "#/components/schemas/QATFinetuningConfig"
-                    }
-                ],
-                "discriminator": {
-                    "propertyName": "type",
-                    "mapping": {
-                        "LoRA": "#/components/schemas/LoraFinetuningConfig",
-                        "QAT": "#/components/schemas/QATFinetuningConfig"
-                    }
-                }
-            },
             "LoraFinetuningConfig": {
                 "type": "object",
                 "properties": {
@@ -11092,7 +11075,14 @@
                     "type": "string"
                 },
                 "algorithm_config": {
-                    "$ref": "#/components/schemas/AlgorithmConfig"
+                    "oneOf": [
+                        {
+                            "$ref": "#/components/schemas/LoraFinetuningConfig"
+                        },
+                        {
+                            "$ref": "#/components/schemas/QATFinetuningConfig"
+                        }
+                    ]
                 }
             },
             "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 45546fa11..5676c91b6 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7500,15 +7500,6 @@ components:
       required:
         - results
       title: ScoreBatchResponse
-    AlgorithmConfig:
-      oneOf:
-        - $ref: '#/components/schemas/LoraFinetuningConfig'
-        - $ref: '#/components/schemas/QATFinetuningConfig'
-      discriminator:
-        propertyName: type
-        mapping:
-          LoRA: '#/components/schemas/LoraFinetuningConfig'
-          QAT: '#/components/schemas/QATFinetuningConfig'
    LoraFinetuningConfig:
      type: object
      properties:
@@ -7592,7 +7583,9 @@
        checkpoint_dir:
          type: string
        algorithm_config:
-          $ref: '#/components/schemas/AlgorithmConfig'
+          oneOf:
+            - $ref: '#/components/schemas/LoraFinetuningConfig'
+            - $ref: '#/components/schemas/QATFinetuningConfig'
      additionalProperties: false
      required:
        - job_uuid
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 636eb7e7b..362f87a26 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,7 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
 
 
 AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
+    Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
     name="AlgorithmConfig",
 )
 
@@ -184,7 +184,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 7e1d8c016..e16e047e5 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
     )
 
 
+class AuthenticationConfig(BaseModel):
+    endpoint: str = Field(
+        ...,
+        description="Endpoint URL to validate authentication tokens",
+    )
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,
@@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
         default=None,
         description="Path to TLS key file for HTTPS",
     )
+    auth: Optional[AuthenticationConfig] = Field(
+        default=None,
+        description="Authentication configuration for the server",
+    )
 
 
 class StackRunConfig(BaseModel):
diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py
new file mode 100644
index 000000000..bb577bae5
--- /dev/null
+++ b/llama_stack/distribution/server/auth.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from urllib.parse import parse_qs
+
+import httpx
+
+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="auth")
+
+
+class AuthenticationMiddleware:
+    def __init__(self, app, auth_endpoint):
+        self.app = app
+        self.auth_endpoint = auth_endpoint
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            headers = dict(scope.get("headers", []))
+            auth_header = headers.get(b"authorization", b"").decode()
+
+            if not auth_header or not auth_header.startswith("Bearer "):
+                return await self._send_auth_error(send, "Missing or invalid Authorization header")
+
+            api_key = auth_header.split("Bearer ", 1)[1]
+
+            path = scope.get("path", "")
+            request_headers = {k.decode(): v.decode() for k, v in headers.items()}
+
+            query_string = scope.get("query_string", b"").decode()
+            params = parse_qs(query_string)
+
+            auth_data = {
+                "api_key": api_key,
+                "request": {
+                    "path": path,
+                    "headers": request_headers,
+                    "params": params,
+                },
+            }
+
+            # Validate with authentication endpoint
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(self.auth_endpoint, json=auth_data)
+                    if response.status_code != 200:
+                        logger.warning(f"Authentication failed: {response.status_code}")
+                        return await self._send_auth_error(send, "Authentication failed")
+            except Exception:
+                logger.exception("Error during authentication")
+                return await self._send_auth_error(send, "Authentication service error")
+
+        return await self.app(scope, receive, send)
+
+    async def _send_auth_error(self, send, message):
+        await send(
+            {
+                "type": "http.response.start",
+                "status": 401,
+                "headers": [[b"content-type", b"application/json"]],
+            }
+        )
+        error_msg = json.dumps({"error": {"message": message}}).encode()
+        await send({"type": "http.response.body", "body": error_msg})
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index b37b3a007..460acbc87 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )
 
+from .auth import AuthenticationMiddleware
 from .endpoints import get_all_api_endpoints
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -351,6 +352,11 @@ def main():
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
+    # Add authentication middleware if configured
+    if config.server.auth and config.server.auth.endpoint:
+        logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
+        app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
+
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py
index b583c93fd..6842b29a7 100644
--- a/llama_stack/distribution/ui/page/distribution/datasets.py
+++ b/llama_stack/distribution/ui/page/distribution/datasets.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def datasets():
diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py
index 1428ae9ab..492be4700 100644
--- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py
+++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def benchmarks():
diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py
index 3141c1627..f29459098 100644
--- a/llama_stack/distribution/ui/page/distribution/models.py
+++ b/llama_stack/distribution/ui/page/distribution/models.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def models():
diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py
index 9aeb7f2a5..c660cb986 100644
--- a/llama_stack/distribution/ui/page/distribution/providers.py
+++ b/llama_stack/distribution/ui/page/distribution/providers.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def providers():
diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py
index 684270d4d..5e10e6e80 100644
--- a/llama_stack/distribution/ui/page/distribution/resources.py
+++ b/llama_stack/distribution/ui/page/distribution/resources.py
@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from page.distribution.benchmarks import benchmarks
-from page.distribution.datasets import datasets
-from page.distribution.models import models
-from page.distribution.scoring_functions import scoring_functions
-from page.distribution.shields import shields
-from page.distribution.vector_dbs import vector_dbs
 from streamlit_option_menu import option_menu
 
+from llama_stack.distribution.ui.page.distribution.datasets import datasets
+from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
+from llama_stack.distribution.ui.page.distribution.models import models
+from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
+from llama_stack.distribution.ui.page.distribution.shields import shields
+from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
+
 
 def resources_page():
     options = [
diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py
index 6a2a08c6d..193146356 100644
--- a/llama_stack/distribution/ui/page/distribution/scoring_functions.py
+++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def scoring_functions():
diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py
index b5ed27ef9..67d66d64f 100644
--- a/llama_stack/distribution/ui/page/distribution/shields.py
+++ b/llama_stack/distribution/ui/page/distribution/shields.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def shields():
diff --git a/llama_stack/distribution/ui/page/distribution/vector_dbs.py b/llama_stack/distribution/ui/page/distribution/vector_dbs.py
index 1c9d06e8d..49a4f25bb 100644
--- a/llama_stack/distribution/ui/page/distribution/vector_dbs.py
+++ b/llama_stack/distribution/ui/page/distribution/vector_dbs.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def vector_dbs():
diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py
index 26bc28451..d7bc6388c 100644
--- a/llama_stack/distribution/ui/page/evaluations/app_eval.py
+++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py
@@ -8,8 +8,9 @@ import json
 
 import pandas as pd
 import streamlit as st
-from modules.api import llama_stack_api
-from modules.utils import process_dataset
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
+from llama_stack.distribution.ui.modules.utils import process_dataset
 
 
 def application_evaluation_page():
diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py
index 7c39adc4a..97f875e17 100644
--- a/llama_stack/distribution/ui/page/evaluations/native_eval.py
+++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py
@@ -8,7 +8,8 @@ import json
 
 import pandas as pd
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def select_benchmark_1():
diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py
index e69f559db..8e7345169 100644
--- a/llama_stack/distribution/ui/page/playground/chat.py
+++ b/llama_stack/distribution/ui/page/playground/chat.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 # Sidebar configurations
 with st.sidebar:
diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py
index 7ee934fb7..e2f451668 100644
--- a/llama_stack/distribution/ui/page/playground/rag.py
+++ b/llama_stack/distribution/ui/page/playground/rag.py
@@ -7,9 +7,10 @@
 import streamlit as st
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.memory_insert_params import Document
-from modules.api import llama_stack_api
-from modules.utils import data_url_from_file
+from llama_stack_client.types.shared.document import Document
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
+from llama_stack.distribution.ui.modules.utils import data_url_from_file
 
 
 def rag_chat_page():
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 738f9ddcd..e8767c2ff 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -10,6 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import copy
 import json
 import logging
 import multiprocessing
@@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
 
 def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    return ProcessingMessageWrapper(**data).payload
+    return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
 
 
 def worker_process_entrypoint(
diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py
index e76edf3a0..b0aec6187 100644
--- a/llama_stack/providers/inline/post_training/common/validator.py
+++ b/llama_stack/providers/inline/post_training/common/validator.py
@@ -9,6 +9,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from typing import Any
+
 from llama_stack.apis.common.type_system import (
     ChatCompletionInputType,
     DialogType,
@@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
     validate_dataset_schema,
 )
 
-EXPECTED_DATASET_SCHEMA = {
+EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
     "instruct": [
         {
             ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
     dataset_type: str,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def:
+        raise ValueError(f"Dataset {dataset_id} does not exist.")
+
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
         raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 64d61b053..fcadd0884 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
         checkpoint_files: List[str],
         output_dir: str,
         model_type: str,
-    ) -> None:
+    ):
         # Fail fast if ``checkpoint_files`` is invalid
         # TODO: support loading more than one file
         if len(checkpoint_files) != 1:
@@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
         """
         Load Meta checkpoint from file. Currently only loading from a single file is supported.
""" - state_dict: Dict[str:Any] = {} + state_dict: Dict[str, Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) if self._model_type == ModelType.LLAMA3_VISION: from torchtune.models.llama3_2_vision._convert_weights import ( @@ -85,10 +85,10 @@ class TorchtuneCheckpointer: state_dict: Dict[str, Any], epoch: int, adapter_only: bool = False, - checkpoint_format: str = "meta", + checkpoint_format: str | None = None, ) -> str: model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}" - if checkpoint_format == "meta": + if checkpoint_format == "meta" or checkpoint_format is None: self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only) elif checkpoint_format == "huggingface": # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 98e16f9d7..f8a1c0436 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Callable, Dict +from typing import Callable, Dict import torch from pydantic import BaseModel @@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +BuildLoraModelCallable = Callable[..., torch.nn.Module] +BuildTokenizerCallable = Callable[..., Llama3Tokenizer] + class ModelConfig(BaseModel): - model_definition: Any - tokenizer_type: Any + model_definition: BuildLoraModelCallable + tokenizer_type: BuildTokenizerCallable checkpoint_type: str @@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = { } -BuildLoraModelCallable = Callable[..., torch.nn.Module] -BuildTokenizerCallable = Callable[..., Llama3Tokenizer] - - def _validate_model_id(model_id: str) -> Model: model = resolve_model(model_id) if model is None or model.core_model_id.value not in MODEL_CONFIGS: diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index b556b59a6..050996860 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -55,7 +55,7 @@ class SFTDataset(Dataset): if "messages" in transformed_sample: validate_messages(transformed_sample["messages"]) - tokenized_dict = self._model_transform(transformed_sample) + tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample) if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 0f89b4064..edc1ceb90 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO 
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
-    AlgorithmConfig,
     Checkpoint,
     LoraFinetuningConfig,
     OptimizerConfig,
+    QATFinetuningConfig,
     TrainingConfig,
 )
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
@@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
     # Currently logging only logs limited training metrics to local disk
     # will figure out more loggings and how it works with telemetry in future PRs
+
+    _checkpointer: TorchtuneCheckpointer
+
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
@@ -82,7 +85,7 @@
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig],
+        algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
@@ -109,12 +112,12 @@
             return str(checkpoint_dir)
 
         if checkpoint_dir and checkpoint_dir != "null":
-            self.checkpoint_dir = config.checkpoint_dir
+            self.checkpoint_dir = checkpoint_dir
         else:
-            model = resolve_model(self.model_id)
-            if model is None:
+            model_obj = resolve_model(self.model_id)
+            if model_obj is None:
                 raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
-            self.checkpoint_dir = model_checkpoint_dir(model)
+            self.checkpoint_dir = model_checkpoint_dir(model_obj)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
         self._checkpoint_format = config.checkpoint_format
@@ -135,16 +138,16 @@
         self.max_validation_steps = training_config.max_validation_steps
         self._clip_grad_norm = 1.0
-        self._enable_activation_checkpointing = (
-            (training_config.efficiency_config.enable_activation_checkpointing)
-            if training_config.efficiency_config
-            else False
-        )
-        self._enable_activation_offloading = (
-            (training_config.efficiency_config.enable_activation_offloading)
-            if training_config.efficiency_config
-            else False
-        )
+
+        self._enable_activation_checkpointing = False
+        self._enable_activation_offloading = False
+        if training_config.efficiency_config:
+            if training_config.efficiency_config.enable_activation_checkpointing:
+                self._enable_activation_checkpointing = (
+                    training_config.efficiency_config.enable_activation_checkpointing
+                )
+            if training_config.efficiency_config.enable_activation_offloading:
+                self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
 
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
@@ -451,12 +454,12 @@
         """
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
-        running_loss = 0
+        running_loss: float = 0.0
         num_tokens = 0
 
         # training artifacts
         checkpoints = []
-        memory_stats = {}
+        memory_stats: Dict[str, Any] = {}
 
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -484,7 +487,7 @@
                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
                 current_loss = await self._loss_step(batch) * current_num_tokens
-                running_loss += current_loss
+                running_loss += current_loss.detach().item()
                 current_loss.backward()
 
                 # Step with optimizer
@@ -500,7 +503,7 @@
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
 
-                    loss_to_log = running_loss.item() / num_tokens
+                    loss_to_log = running_loss / num_tokens
 
                     pbar.update(1)
                     pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
@@ -523,7 +526,7 @@
                 )
 
                 # Reset running stats for the next step
-                running_loss = 0
+                running_loss = 0.0
                 num_tokens = 0
                 t0 = time.perf_counter()
 
diff --git a/pyproject.toml b/pyproject.toml
index f57b91462..107150cee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -228,10 +228,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
     "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py
new file mode 100644
index 000000000..70f08dbd6
--- /dev/null
+++ b/tests/unit/server/test_auth.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from llama_stack.distribution.server.auth import AuthenticationMiddleware
+
+
+@pytest.fixture
+def mock_auth_endpoint():
+    return "http://mock-auth-service/validate"
+
+
+@pytest.fixture
+def valid_api_key():
+    return "valid_api_key_12345"
+
+
+@pytest.fixture
+def invalid_api_key():
+    return "invalid_api_key_67890"
+
+
+@pytest.fixture
+def app(mock_auth_endpoint):
+    app = FastAPI()
+    app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
+
+    @app.get("/test")
+    def test_endpoint():
+        return {"message": "Authentication successful"}
+
+    return app
+
+
+@pytest.fixture
+def client(app):
+    return TestClient(app)
+
+
+async def mock_post_success(*args, **kwargs):
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    return mock_response
+
+
+async def mock_post_failure(*args, **kwargs):
+    mock_response = AsyncMock()
+    mock_response.status_code = 401
+    return mock_response
+
+
+async def mock_post_exception(*args, **kwargs):
+    raise Exception("Connection error")
+
+
+def test_missing_auth_header(client):
+    response = client.get("/test")
+    assert response.status_code == 401
+    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+
+
+def test_invalid_auth_header_format(client):
+    response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
+    assert response.status_code == 401
+    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_success)
+def test_valid_authentication(client, valid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
+    assert response.status_code == 200
+    assert response.json() == {"message": "Authentication successful"}
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_failure)
+def test_invalid_authentication(client, invalid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
+    assert response.status_code == 401
+    assert "Authentication failed" in response.json()["error"]["message"]
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_exception)
+def test_auth_service_error(client, valid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
+    assert response.status_code == 401
+    assert "Authentication service error" in response.json()["error"]["message"]
+
+
+def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
+    with patch("httpx.AsyncClient.post") as mock_post:
+        mock_response = AsyncMock()
+        mock_response.status_code = 200
+        mock_post.return_value = mock_response
+
+        client.get(
+            "/test?param1=value1&param2=value2",
+            headers={
+                "Authorization": f"Bearer {valid_api_key}",
+                "User-Agent": "TestClient",
+                "Content-Type": "application/json",
+            },
+        )
+
+        # Check that the auth endpoint was called with the correct payload
+        call_args = mock_post.call_args
+        assert call_args is not None
+
+        url, kwargs = call_args[0][0], call_args[1]
+        assert url == mock_auth_endpoint
+
+        payload = kwargs["json"]
+        assert payload["api_key"] == valid_api_key
+        assert payload["request"]["path"] == "/test"
+        assert "authorization" in payload["request"]["headers"]
+        assert "param1" in payload["request"]["params"]
+        assert "param2" in payload["request"]["params"]
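
To exercise the new authentication path end to end: `AuthenticationMiddleware` POSTs a JSON body of the shape `{"api_key": ..., "request": {"path": ..., "headers": ..., "params": ...}}` to the URL configured at `server.auth.endpoint` in the run config, and treats any non-200 response as a denial. Below is a minimal sketch of a compatible validation service; the `/validate` route name and the `KNOWN_KEYS` store are illustrative assumptions, not part of this PR.

```python
# Hypothetical token-validation service compatible with AuthenticationMiddleware.
# KNOWN_KEYS and the /validate route are illustrative, not part of this PR.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

KNOWN_KEYS = {"valid_api_key_12345"}  # stand-in for a real key store


class RequestInfo(BaseModel):
    path: str
    headers: dict[str, str]
    params: dict[str, list[str]]  # parse_qs yields list-valued params


class AuthRequest(BaseModel):
    api_key: str
    request: RequestInfo


@app.post("/validate")
def validate(body: AuthRequest) -> dict[str, bool]:
    # The middleware only inspects the status code: 200 lets the request
    # through; anything else surfaces to the client as "Authentication failed".
    if body.api_key not in KNOWN_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return {"ok": True}
```

With `server.auth.endpoint` pointed at such a service, clients must send an `Authorization: Bearer <key>` header; requests missing it are rejected by the middleware before they reach the stack.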
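On the `AlgorithmConfig` change: the union is still registered under the `AlgorithmConfig` name with `type` as the discriminator, so the wire format is unchanged; the signatures now just spell the union inline. A sketch of constructing the LoRA variant to pass as `algorithm_config` — the values are illustrative, and the field names follow the upstream `LoraFinetuningConfig` schema, whose properties this diff elides:

```python
from llama_stack.apis.post_training import AlgorithmConfig, LoraFinetuningConfig

# "LoRA" is the discriminator value mapped to LoraFinetuningConfig
# (see the mapping removed from the spec above); values are illustrative.
algorithm_config: AlgorithmConfig = LoraFinetuningConfig(
    type="LoRA",
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
```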