From 22e560351e5b74ffcb3bfd5a30a9b4772c130b57 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 18 Mar 2025 17:39:22 -0400 Subject: [PATCH 1/6] ci: Add scheduled workflow to update changelog (#1503) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This is a follow up from https://github.com/meta-llama/llama-stack/pull/1463. cc @yanxi0830 --------- Signed-off-by: Yuan Tang Co-authored-by: Sébastien Han --- .github/workflows/changelog.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/changelog.yml diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 000000000..5b63e231c --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,29 @@ +name: Update Changelog + +on: + release: + types: [published, unpublished, created, edited, deleted, released] + +permissions: + contents: read + +jobs: + generate_changelog: + name: Generate changelog + permissions: + contents: write # for peter-evans/create-pull-request to create branch + pull-requests: write # for peter-evans/create-pull-request to create a PR + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + - run: | + python ./scripts/gen-changelog.py + - uses: peter-evans/create-pull-request@v7 + with: + title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}' + commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}' + branch: create-pull-request/changelog + signoff: true From f86f3cf8783e8923f9c67658d06187a6535e842f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 18 Mar 2025 22:52:21 +0100 Subject: [PATCH 2/6] docs: remove redundant installation instructions (#1138) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The previous installation instructions were mostly duplicating information already covered in the documentation, either in the “Start a Server” or “Contributing Guide” sections. Removed these redundant details to avoid confusion and streamline the setup process. Signed-off-by: Sébastien Han Signed-off-by: Sébastien Han --- README.md | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/README.md b/README.md index d2adc3376..918433d51 100644 --- a/README.md +++ b/README.md @@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) | | vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) | -### Installation - -You have two ways to install this repository: - -* **Install as a package**: - You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: - ```bash - pip install llama-stack - ``` - -* **Install from source**: - If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv). - Then, run the following commands: - ```bash - git clone git@github.com:meta-llama/llama-stack.git - cd llama-stack - - uv sync - uv pip install -e . 
- ``` ### Documentation From 0cbb7f7f21982bf943a257009bd916dbfe510122 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 18 Mar 2025 17:58:16 -0400 Subject: [PATCH 3/6] chore: fix mypy violations in post_training modules (#1548) # What does this PR do? Fixes a bunch of violations. Note: this patch touches all files but post_training.py that will be significantly changed by #1437, hence leaving it out of the picture for now. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Testing with https://github.com/meta-llama/llama-stack/pull/1543 Also checked that GPU training works with the change: ``` INFO: ::1:53316 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK INFO: ::1:53316 - "GET /v1/post-training/job/status?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK INFO: ::1:53316 - "GET /v1/post-training/job/artifacts?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK 21:24:01.161 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (32526.75ms) 21:23:28.769 [DEBUG] Setting manual seed to local seed 3918872849. Local seed is seed + rank = 3918872849 + 0 21:23:28.996 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights. 21:23:29.933 [INFO] Memory stats after model init: GPU peak memory allocation: 6.05 GiB GPU peak memory reserved: 6.10 GiB GPU peak memory active: 6.05 GiB 21:23:29.934 [INFO] Model is initialized with precision torch.bfloat16. 21:23:30.115 [INFO] Tokenizer is initialized. 21:23:30.118 [INFO] Optimizer is initialized. 21:23:30.119 [INFO] Loss is initialized. 21:23:30.896 [INFO] Dataset and Sampler are initialized. 21:23:30.898 [INFO] Learning rate scheduler is initialized. 21:23:31.618 [INFO] Memory stats after model init: GPU peak memory allocation: 6.24 GiB GPU peak memory reserved: 6.30 GiB GPU peak memory active: 6.24 GiB 21:23:31.620 [INFO] Starting checkpoint save... 
21:23:59.428 [INFO] Model checkpoint of size 6.43 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth 21:23:59.445 [INFO] Adapter checkpoint of size 0.00 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth ``` [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- docs/_static/llama-stack-spec.html | 26 ++++------- docs/_static/llama-stack-spec.yaml | 13 ++---- .../apis/post_training/post_training.py | 6 +-- .../inline/post_training/common/validator.py | 8 +++- .../torchtune/common/checkpointer.py | 8 ++-- .../post_training/torchtune/common/utils.py | 13 +++--- .../post_training/torchtune/datasets/sft.py | 2 +- .../recipes/lora_finetuning_single_device.py | 45 ++++++++++--------- pyproject.toml | 4 -- 9 files changed, 56 insertions(+), 69 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 72b2e6b17..2362dfa53 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9847,23 +9847,6 @@ ], "title": "ScoreBatchResponse" }, - "AlgorithmConfig": { - "oneOf": [ - { - "$ref": "#/components/schemas/LoraFinetuningConfig" - }, - { - "$ref": "#/components/schemas/QATFinetuningConfig" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "LoRA": "#/components/schemas/LoraFinetuningConfig", - "QAT": "#/components/schemas/QATFinetuningConfig" - } - } - }, "LoraFinetuningConfig": { "type": "object", "properties": { @@ -9999,7 +9982,14 @@ "type": "string" }, "algorithm_config": { - "$ref": "#/components/schemas/AlgorithmConfig" + "oneOf": [ + { + "$ref": "#/components/schemas/LoraFinetuningConfig" + }, + { + "$ref": "#/components/schemas/QATFinetuningConfig" + } + ] } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 6f4a9528b..38e08e41c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6678,15 +6678,6 @@ components: required: - results title: ScoreBatchResponse - AlgorithmConfig: - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - $ref: '#/components/schemas/QATFinetuningConfig' - discriminator: - propertyName: type - mapping: - LoRA: '#/components/schemas/LoraFinetuningConfig' - QAT: '#/components/schemas/QATFinetuningConfig' LoraFinetuningConfig: type: object properties: @@ -6770,7 +6761,9 @@ components: checkpoint_dir: type: string algorithm_config: - $ref: '#/components/schemas/AlgorithmConfig' + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QATFinetuningConfig' additionalProperties: false required: - job_uuid diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 636eb7e7b..362f87a26 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel): AlgorithmConfig = register_schema( - Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], + Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], 
name="AlgorithmConfig", ) @@ -184,7 +184,7 @@ class PostTraining(Protocol): description="Model descriptor from `llama model list`", ), checkpoint_dir: Optional[str] = None, - algorithm_config: Optional[AlgorithmConfig] = None, + algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize", method="POST") diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py index e76edf3a0..b0aec6187 100644 --- a/llama_stack/providers/inline/post_training/common/validator.py +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -9,6 +9,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from typing import Any + from llama_stack.apis.common.type_system import ( ChatCompletionInputType, DialogType, @@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import ( validate_dataset_schema, ) -EXPECTED_DATASET_SCHEMA = { +EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = { "instruct": [ { ColumnName.chat_completion_input.value: ChatCompletionInputType(), @@ -41,6 +44,9 @@ async def validate_input_dataset_schema( dataset_type: str, ) -> None: dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def: + raise ValueError(f"Dataset {dataset_id} does not exist.") + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py index 64d61b053..fcadd0884 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py @@ -37,7 +37,7 @@ class TorchtuneCheckpointer: checkpoint_files: List[str], output_dir: str, model_type: str, - ) -> None: + ): # Fail fast if ``checkpoint_files`` is invalid # TODO: support loading more than one file if len(checkpoint_files) != 1: @@ -58,7 +58,7 @@ class TorchtuneCheckpointer: """ Load Meta checkpoint from file. Currently only loading from a single file is supported. 
""" - state_dict: Dict[str:Any] = {} + state_dict: Dict[str, Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) if self._model_type == ModelType.LLAMA3_VISION: from torchtune.models.llama3_2_vision._convert_weights import ( @@ -85,10 +85,10 @@ class TorchtuneCheckpointer: state_dict: Dict[str, Any], epoch: int, adapter_only: bool = False, - checkpoint_format: str = "meta", + checkpoint_format: str | None = None, ) -> str: model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}" - if checkpoint_format == "meta": + if checkpoint_format == "meta" or checkpoint_format is None: self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only) elif checkpoint_format == "huggingface": # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 98e16f9d7..f8a1c0436 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Callable, Dict +from typing import Callable, Dict import torch from pydantic import BaseModel @@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +BuildLoraModelCallable = Callable[..., torch.nn.Module] +BuildTokenizerCallable = Callable[..., Llama3Tokenizer] + class ModelConfig(BaseModel): - model_definition: Any - tokenizer_type: Any + model_definition: BuildLoraModelCallable + tokenizer_type: BuildTokenizerCallable checkpoint_type: str @@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = { } -BuildLoraModelCallable = Callable[..., torch.nn.Module] -BuildTokenizerCallable = Callable[..., Llama3Tokenizer] - - def _validate_model_id(model_id: str) -> Model: model = resolve_model(model_id) if model is None or model.core_model_id.value not in MODEL_CONFIGS: diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index b556b59a6..050996860 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -55,7 +55,7 @@ class SFTDataset(Dataset): if "messages" in transformed_sample: validate_messages(transformed_sample["messages"]) - tokenized_dict = self._model_transform(transformed_sample) + tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample) if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 0f89b4064..edc1ceb90 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO 
from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( - AlgorithmConfig, Checkpoint, LoraFinetuningConfig, OptimizerConfig, + QATFinetuningConfig, TrainingConfig, ) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR @@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice: # Currently logging only logs limited training metrics to local disk # will figure out more loggings and how it works with telemetry in future PRs + + _checkpointer: TorchtuneCheckpointer + def __init__( self, config: TorchtunePostTrainingConfig, @@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice: logger_config: Dict[str, Any], model: str, checkpoint_dir: Optional[str], - algorithm_config: Optional[AlgorithmConfig], + algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None, datasetio_api: DatasetIO, datasets_api: Datasets, ) -> None: @@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice: return str(checkpoint_dir) if checkpoint_dir and checkpoint_dir != "null": - self.checkpoint_dir = config.checkpoint_dir + self.checkpoint_dir = checkpoint_dir else: - model = resolve_model(self.model_id) - if model is None: + model_obj = resolve_model(self.model_id) + if model_obj is None: raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list") - self.checkpoint_dir = model_checkpoint_dir(model) + self.checkpoint_dir = model_checkpoint_dir(model_obj) self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._checkpoint_format = config.checkpoint_format @@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice: self.max_validation_steps = training_config.max_validation_steps self._clip_grad_norm = 1.0 - self._enable_activation_checkpointing = ( - (training_config.efficiency_config.enable_activation_checkpointing) - if training_config.efficiency_config - else False - ) - self._enable_activation_offloading = ( - (training_config.efficiency_config.enable_activation_offloading) - if training_config.efficiency_config - else False - ) + + self._enable_activation_checkpointing = False + self._enable_activation_offloading = False + if training_config.efficiency_config: + if training_config.efficiency_config.enable_activation_checkpointing: + self._enable_activation_checkpointing = ( + training_config.efficiency_config.enable_activation_checkpointing + ) + if training_config.efficiency_config.enable_activation_offloading: + self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading self.datasetio_api = datasetio_api self.datasets_api = datasets_api @@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice: """ # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() - running_loss = 0 + running_loss: float = 0.0 num_tokens = 0 # training artifacts checkpoints = [] - memory_stats = {} + memory_stats: Dict[str, Any] = {} # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients current_loss = await self._loss_step(batch) * current_num_tokens - running_loss += current_loss + running_loss += current_loss.detach().item() current_loss.backward() # Step with optimizer @@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice: # Update the number of steps when the weights are updated 
self.global_step += 1 - loss_to_log = running_loss.item() / num_tokens + loss_to_log = running_loss / num_tokens pbar.update(1) pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") @@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice: ) # Reset running stats for the next step - running_loss = 0 + running_loss = 0.0 num_tokens = 0 t0 = time.perf_counter() diff --git a/pyproject.toml b/pyproject.toml index f57b91462..107150cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,10 +228,6 @@ exclude = [ "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/post_training/common/validator\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$", "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$", "^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/llama_guard/", From 9c8e88ea9ca756dd10b2db0e68a4166e35c6e5ff Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com> Date: Wed, 19 Mar 2025 03:30:48 +0530 Subject: [PATCH 4/6] fix: Fixed import errors for UI and playground (#1666) # What does this PR do? Fixed import errors for playground and ui --------- Co-authored-by: sarthakdeshpande --- .../distribution/ui/page/distribution/datasets.py | 3 ++- .../distribution/ui/page/distribution/eval_tasks.py | 3 ++- .../distribution/ui/page/distribution/models.py | 3 ++- .../distribution/ui/page/distribution/providers.py | 3 ++- .../distribution/ui/page/distribution/resources.py | 13 +++++++------ .../ui/page/distribution/scoring_functions.py | 3 ++- .../distribution/ui/page/distribution/shields.py | 3 ++- .../distribution/ui/page/distribution/vector_dbs.py | 3 ++- .../distribution/ui/page/evaluations/app_eval.py | 5 +++-- .../distribution/ui/page/evaluations/native_eval.py | 3 ++- llama_stack/distribution/ui/page/playground/chat.py | 3 ++- llama_stack/distribution/ui/page/playground/rag.py | 7 ++++--- 12 files changed, 32 insertions(+), 20 deletions(-) diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py index b583c93fd..6842b29a7 100644 --- a/llama_stack/distribution/ui/page/distribution/datasets.py +++ b/llama_stack/distribution/ui/page/distribution/datasets.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def datasets(): diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py index 1428ae9ab..492be4700 100644 --- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -5,7 +5,8 @@ # the root directory of this source tree. 
import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def benchmarks(): diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py index 3141c1627..f29459098 100644 --- a/llama_stack/distribution/ui/page/distribution/models.py +++ b/llama_stack/distribution/ui/page/distribution/models.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def models(): diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py index 9aeb7f2a5..c660cb986 100644 --- a/llama_stack/distribution/ui/page/distribution/providers.py +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def providers(): diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py index 684270d4d..5e10e6e80 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -4,14 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from page.distribution.benchmarks import benchmarks -from page.distribution.datasets import datasets -from page.distribution.models import models -from page.distribution.scoring_functions import scoring_functions -from page.distribution.shields import shields -from page.distribution.vector_dbs import vector_dbs from streamlit_option_menu import option_menu +from llama_stack.distribution.ui.page.distribution.datasets import datasets +from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks +from llama_stack.distribution.ui.page.distribution.models import models +from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions +from llama_stack.distribution.ui.page.distribution.shields import shields +from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs + def resources_page(): options = [ diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py index 6a2a08c6d..193146356 100644 --- a/llama_stack/distribution/ui/page/distribution/scoring_functions.py +++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def scoring_functions(): diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py index b5ed27ef9..67d66d64f 100644 --- a/llama_stack/distribution/ui/page/distribution/shields.py +++ b/llama_stack/distribution/ui/page/distribution/shields.py @@ -5,7 +5,8 @@ # the root directory of this source tree. 
import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def shields(): diff --git a/llama_stack/distribution/ui/page/distribution/vector_dbs.py b/llama_stack/distribution/ui/page/distribution/vector_dbs.py index 1c9d06e8d..49a4f25bb 100644 --- a/llama_stack/distribution/ui/page/distribution/vector_dbs.py +++ b/llama_stack/distribution/ui/page/distribution/vector_dbs.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def vector_dbs(): diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py index 26bc28451..d7bc6388c 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -8,8 +8,9 @@ import json import pandas as pd import streamlit as st -from modules.api import llama_stack_api -from modules.utils import process_dataset + +from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.distribution.ui.modules.utils import process_dataset def application_evaluation_page(): diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 7c39adc4a..97f875e17 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -8,7 +8,8 @@ import json import pandas as pd import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def select_benchmark_1(): diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py index e69f559db..8e7345169 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api # Sidebar configurations with st.sidebar: diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 7ee934fb7..e2f451668 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -7,9 +7,10 @@ import streamlit as st from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types.memory_insert_params import Document -from modules.api import llama_stack_api -from modules.utils import data_url_from_file +from llama_stack_client.types.shared.document import Document + +from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.distribution.ui.modules.utils import data_url_from_file def rag_chat_page(): From b79e0435de6be38a6dd4061b8748939305815750 Mon Sep 17 00:00:00 2001 From: yyymeta <123776235+yyymeta@users.noreply.github.com> Date: Tue, 18 Mar 2025 16:17:29 -0700 Subject: [PATCH 5/6] fix: avoid tensor memory error (#1688) # What does this PR do? 
we randomly get errors like the following, it's most likely due to accessing an object that is already deallocated ``` E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Traceback (most recent call last): E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] fn(i, *args) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 611, in _wrap E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] ret = record(fn)(*args_) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return f(*args, **kwargs) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/internal-llama-stack/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py", line 249, in worker_process_entrypoint E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] task = req_gen.send(result) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/internal-llama-stack/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py", line 156, in retrieve_requests E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] torch.distributed.broadcast_object_list( E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return func(*args, **kwargs) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3504, in broadcast_object_list E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] object_list[i] = _tensor_to_object(obj_view, obj_size, group) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2961, in _tensor_to_object E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return _unpickler(io.BytesIO(buf)).load() E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] EOFError: Ran out of input E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Process SpawnProcess-1: Traceback (most recent call last): ``` ## Test Plan start server ``` llama-stack-client eval run-benchmark mmmu_v1 --model-id 
meta-llama/Llama-4-17B-Omni-Instruct --output-dir /tmp/mmmu_standard --num-examples 30 ``` [//]: # (## Documentation) --- .../inline/inference/meta_reference/parallel_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 738f9ddcd..e8767c2ff 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -10,6 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import copy import json import logging import multiprocessing @@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage def parse_message(json_str: str) -> ProcessingMessage: data = json.loads(json_str) - return ProcessingMessageWrapper(**data).payload + return copy.deepcopy(ProcessingMessageWrapper(**data).payload) def worker_process_entrypoint( From 5b39d5a76af13f055974c5cd1d66a31c92f01ccd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 18 Mar 2025 16:24:18 -0700 Subject: [PATCH 6/6] feat(auth, rfc): Add support for Bearer (api_key) Authentication (#1626) This PR adds support (or is a proposal) for API key authentication on the Llama Stack server end. `llama-stack-client` already supports accepting an api_key parameter and passes it down through every request as an `Authorization: Bearer <api_key>` header. Currently, Llama Stack does not propose APIs for handling authentication or authorization for resources of any kind. Given that, and the fact that any deployment will typically have _some_ authentication system present, we simply adopt a delegation mechanism: delegate to an HTTPS endpoint performing key management / authentication. It is configured via: ```yaml server: auth: endpoint: <...> ``` in the run.yaml configuration. ## How It Works When authentication is enabled: 1. Every API request must include an `Authorization: Bearer <api_key>` header 2. The server will send a _POST_ validation request to the configured endpoint with the following payload: ```json { "api_key": "<api_key>", "request": { "path": "/api/path", "headers": { "header1": "value1", ... }, "params": { "param1": "value1", ... } } } ``` 3. If the authentication endpoint returns a 200 status code, the request is allowed to proceed 4.
If the authentication endpoint returns any other status code, a 401 Unauthorized response is returned ## Test Plan Unit tests --- llama_stack/distribution/datatypes.py | 11 ++ llama_stack/distribution/server/auth.py | 69 ++++++++++++ llama_stack/distribution/server/server.py | 6 ++ tests/unit/server/test_auth.py | 124 ++++++++++++++++++++++ 4 files changed, 210 insertions(+) create mode 100644 llama_stack/distribution/server/auth.py create mode 100644 tests/unit/server/test_auth.py diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 7e1d8c016..e16e047e5 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -125,6 +125,13 @@ class LoggingConfig(BaseModel): ) +class AuthenticationConfig(BaseModel): + endpoint: str = Field( + ..., + description="Endpoint URL to validate authentication tokens", + ) + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -140,6 +147,10 @@ class ServerConfig(BaseModel): default=None, description="Path to TLS key file for HTTPS", ) + auth: Optional[AuthenticationConfig] = Field( + default=None, + description="Authentication configuration for the server", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py new file mode 100644 index 000000000..bb577bae5 --- /dev/null +++ b/llama_stack/distribution/server/auth.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from urllib.parse import parse_qs + +import httpx + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="auth") + + +class AuthenticationMiddleware: + def __init__(self, app, auth_endpoint): + self.app = app + self.auth_endpoint = auth_endpoint + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + headers = dict(scope.get("headers", [])) + auth_header = headers.get(b"authorization", b"").decode() + + if not auth_header or not auth_header.startswith("Bearer "): + return await self._send_auth_error(send, "Missing or invalid Authorization header") + + api_key = auth_header.split("Bearer ", 1)[1] + + path = scope.get("path", "") + request_headers = {k.decode(): v.decode() for k, v in headers.items()} + + query_string = scope.get("query_string", b"").decode() + params = parse_qs(query_string) + + auth_data = { + "api_key": api_key, + "request": { + "path": path, + "headers": request_headers, + "params": params, + }, + } + + # Validate with authentication endpoint + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.auth_endpoint, json=auth_data) + if response.status_code != 200: + logger.warning(f"Authentication failed: {response.status_code}") + return await self._send_auth_error(send, "Authentication failed") + except Exception: + logger.exception("Error during authentication") + return await self._send_auth_error(send, "Authentication service error") + + return await self.app(scope, receive, send) + + async def _send_auth_error(self, send, message): + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [[b"content-type", b"application/json"]], + } + ) + error_msg = json.dumps({"error": {"message": message}}).encode() + await send({"type": "http.response.body", "body": error_msg}) diff --git 
a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b37b3a007..460acbc87 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) +from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -351,6 +352,11 @@ def main(): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) + # Add authentication middleware if configured + if config.server.auth and config.server.auth.endpoint: + logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}") + app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint) + try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py new file mode 100644 index 000000000..70f08dbd6 --- /dev/null +++ b/tests/unit/server/test_auth.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from llama_stack.distribution.server.auth import AuthenticationMiddleware + + +@pytest.fixture +def mock_auth_endpoint(): + return "http://mock-auth-service/validate" + + +@pytest.fixture +def valid_api_key(): + return "valid_api_key_12345" + + +@pytest.fixture +def invalid_api_key(): + return "invalid_api_key_67890" + + +@pytest.fixture +def app(mock_auth_endpoint): + app = FastAPI() + app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +async def mock_post_success(*args, **kwargs): + mock_response = AsyncMock() + mock_response.status_code = 200 + return mock_response + + +async def mock_post_failure(*args, **kwargs): + mock_response = AsyncMock() + mock_response.status_code = 401 + return mock_response + + +async def mock_post_exception(*args, **kwargs): + raise Exception("Connection error") + + +def test_missing_auth_header(client): + response = client.get("/test") + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format(client): + response = client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_post_success) +def test_valid_authentication(client, valid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +@patch("httpx.AsyncClient.post", new=mock_post_failure) +def test_invalid_authentication(client, invalid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Authentication 
failed" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_post_exception) +def test_auth_service_error(client, valid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 401 + assert "Authentication service error" in response.json()["error"]["message"] + + +def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): + with patch("httpx.AsyncClient.post") as mock_post: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + client.get( + "/test?param1=value1¶m2=value2", + headers={ + "Authorization": f"Bearer {valid_api_key}", + "User-Agent": "TestClient", + "Content-Type": "application/json", + }, + ) + + # Check that the auth endpoint was called with the correct payload + call_args = mock_post.call_args + assert call_args is not None + + url, kwargs = call_args[0][0], call_args[1] + assert url == mock_auth_endpoint + + payload = kwargs["json"] + assert payload["api_key"] == valid_api_key + assert payload["request"]["path"] == "/test" + assert "authorization" in payload["request"]["headers"] + assert "param1" in payload["request"]["params"] + assert "param2" in payload["request"]["params"]