forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into eval_api_final
This commit is contained in:
commit
24d48b3692
28 changed files with 329 additions and 110 deletions
29
.github/workflows/changelog.yml
vendored
Normal file
29
.github/workflows/changelog.yml
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
name: Update Changelog
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published, unpublished, created, edited, deleted, released]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
generate_changelog:
|
||||
name: Generate changelog
|
||||
permissions:
|
||||
contents: write # for peter-evans/create-pull-request to create branch
|
||||
pull-requests: write # for peter-evans/create-pull-request to create a PR
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
fetch-depth: 0
|
||||
- run: |
|
||||
python ./scripts/gen-changelog.py
|
||||
- uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
|
||||
commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
|
||||
branch: create-pull-request/changelog
|
||||
signoff: true
|
20
README.md
20
README.md
|
@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
|
|||
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
|
||||
| vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
|
||||
|
||||
### Installation
|
||||
|
||||
You have two ways to install this repository:
|
||||
|
||||
* **Install as a package**:
|
||||
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
|
||||
```bash
|
||||
pip install llama-stack
|
||||
```
|
||||
|
||||
* **Install from source**:
|
||||
If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
|
||||
Then, run the following commands:
|
||||
```bash
|
||||
git clone git@github.com:meta-llama/llama-stack.git
|
||||
cd llama-stack
|
||||
|
||||
uv sync
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
||||
|
|
26
docs/_static/llama-stack-spec.html
vendored
26
docs/_static/llama-stack-spec.html
vendored
|
@ -10940,23 +10940,6 @@
|
|||
],
|
||||
"title": "ScoreBatchResponse"
|
||||
},
|
||||
"AlgorithmConfig": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/LoraFinetuningConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"LoRA": "#/components/schemas/LoraFinetuningConfig",
|
||||
"QAT": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
"LoraFinetuningConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -11092,7 +11075,14 @@
|
|||
"type": "string"
|
||||
},
|
||||
"algorithm_config": {
|
||||
"$ref": "#/components/schemas/AlgorithmConfig"
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/LoraFinetuningConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
13
docs/_static/llama-stack-spec.yaml
vendored
13
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7500,15 +7500,6 @@ components:
|
|||
required:
|
||||
- results
|
||||
title: ScoreBatchResponse
|
||||
AlgorithmConfig:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/LoraFinetuningConfig'
|
||||
- $ref: '#/components/schemas/QATFinetuningConfig'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
LoRA: '#/components/schemas/LoraFinetuningConfig'
|
||||
QAT: '#/components/schemas/QATFinetuningConfig'
|
||||
LoraFinetuningConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -7592,7 +7583,9 @@ components:
|
|||
checkpoint_dir:
|
||||
type: string
|
||||
algorithm_config:
|
||||
$ref: '#/components/schemas/AlgorithmConfig'
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/LoraFinetuningConfig'
|
||||
- $ref: '#/components/schemas/QATFinetuningConfig'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_uuid
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
|
|||
|
||||
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
|
||||
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
|
||||
|
@ -184,7 +184,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
|
@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Endpoint URL to validate authentication tokens",
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
auth: Optional[AuthenticationConfig] = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
69
llama_stack/distribution/server/auth.py
Normal file
69
llama_stack/distribution/server/auth.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(self, app, auth_endpoint):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
api_key = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
auth_data = {
|
||||
"api_key": api_key,
|
||||
"request": {
|
||||
"path": path,
|
||||
"headers": request_headers,
|
||||
"params": params,
|
||||
},
|
||||
}
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(self.auth_endpoint, json=auth_data)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 401,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
|
@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -351,6 +352,11 @@ def main():
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth and config.server.auth.endpoint:
|
||||
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def benchmarks():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
|
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from page.distribution.benchmarks import benchmarks
|
||||
from page.distribution.datasets import datasets
|
||||
from page.distribution.models import models
|
||||
from page.distribution.scoring_functions import scoring_functions
|
||||
from page.distribution.shields import shields
|
||||
from page.distribution.vector_dbs import vector_dbs
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.distribution.ui.page.distribution.models import models
|
||||
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def shields():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def vector_dbs():
|
||||
|
|
|
@ -8,8 +8,9 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import process_dataset
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import process_dataset
|
||||
|
||||
|
||||
def application_evaluation_page():
|
||||
|
|
|
@ -8,7 +8,8 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_benchmark_1():
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
# Sidebar configurations
|
||||
with st.sidebar:
|
||||
|
|
|
@ -7,9 +7,10 @@
|
|||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import data_url_from_file
|
||||
from llama_stack_client.types.shared.document import Document
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
|
||||
def parse_message(json_str: str) -> ProcessingMessage:
|
||||
data = json.loads(json_str)
|
||||
return ProcessingMessageWrapper(**data).payload
|
||||
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
|
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = {
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
|
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
|
|||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def:
|
||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
||||
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
|
|||
checkpoint_files: List[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
):
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
|
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
|
|||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str:Any] = {}
|
||||
state_dict: Dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
|
|||
state_dict: Dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str = "meta",
|
||||
checkpoint_format: str | None = None,
|
||||
) -> str:
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta":
|
||||
if checkpoint_format == "meta" or checkpoint_format is None:
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
|
|||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
model_definition: BuildLoraModelCallable
|
||||
tokenizer_type: BuildTokenizerCallable
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
|
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
|
|||
}
|
||||
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
|
|
|
@ -55,7 +55,7 @@ class SFTDataset(Dataset):
|
|||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
||||
tokenized_dict = self._model_transform(transformed_sample)
|
||||
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
|
||||
|
||||
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
|
||||
keys_str = ", ".join(tokenized_dict.keys())
|
||||
|
|
|
@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
# Currently logging only logs limited training metrics to local disk
|
||||
# will figure out more loggings and how it works with telemetry in future PRs
|
||||
|
||||
_checkpointer: TorchtuneCheckpointer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
|
@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
|
@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
|
|||
return str(checkpoint_dir)
|
||||
|
||||
if checkpoint_dir and checkpoint_dir != "null":
|
||||
self.checkpoint_dir = config.checkpoint_dir
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
model_obj = resolve_model(self.model_id)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model_obj)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
self._checkpoint_format = config.checkpoint_format
|
||||
|
@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
|
|||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
(training_config.efficiency_config.enable_activation_checkpointing)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = (
|
||||
(training_config.efficiency_config.enable_activation_offloading)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
|
||||
self._enable_activation_checkpointing = False
|
||||
self._enable_activation_offloading = False
|
||||
if training_config.efficiency_config:
|
||||
if training_config.efficiency_config.enable_activation_checkpointing:
|
||||
self._enable_activation_checkpointing = (
|
||||
training_config.efficiency_config.enable_activation_checkpointing
|
||||
)
|
||||
if training_config.efficiency_config.enable_activation_offloading:
|
||||
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss = 0
|
||||
running_loss: float = 0.0
|
||||
num_tokens = 0
|
||||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats = {}
|
||||
memory_stats: Dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
|
@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
current_loss = await self._loss_step(batch) * current_num_tokens
|
||||
running_loss += current_loss
|
||||
running_loss += current_loss.detach().item()
|
||||
current_loss.backward()
|
||||
|
||||
# Step with optimizer
|
||||
|
@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
loss_to_log = running_loss / num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
|
||||
|
@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
running_loss = 0.0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
|
|
|
@ -228,10 +228,6 @@ exclude = [
|
|||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||
"^llama_stack/providers/inline/inference/vllm/",
|
||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
|
||||
"^llama_stack/providers/inline/safety/code_scanner/",
|
||||
"^llama_stack/providers/inline/safety/llama_guard/",
|
||||
|
|
124
tests/unit/server/test_auth.py
Normal file
124
tests/unit/server/test_auth.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
return "http://mock-auth-service/validate"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_api_key():
|
||||
return "valid_api_key_12345"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_api_key():
|
||||
return "invalid_api_key_67890"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
return mock_response
|
||||
|
||||
|
||||
async def mock_post_failure(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
return mock_response
|
||||
|
||||
|
||||
async def mock_post_exception(*args, **kwargs):
|
||||
raise Exception("Connection error")
|
||||
|
||||
|
||||
def test_missing_auth_header(client):
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(client):
|
||||
response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||
def test_valid_authentication(client, valid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_failure)
|
||||
def test_invalid_authentication(client, invalid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_exception)
|
||||
def test_auth_service_error(client, valid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication service error" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client.get(
|
||||
"/test?param1=value1¶m2=value2",
|
||||
headers={
|
||||
"Authorization": f"Bearer {valid_api_key}",
|
||||
"User-Agent": "TestClient",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
# Check that the auth endpoint was called with the correct payload
|
||||
call_args = mock_post.call_args
|
||||
assert call_args is not None
|
||||
|
||||
url, kwargs = call_args[0][0], call_args[1]
|
||||
assert url == mock_auth_endpoint
|
||||
|
||||
payload = kwargs["json"]
|
||||
assert payload["api_key"] == valid_api_key
|
||||
assert payload["request"]["path"] == "/test"
|
||||
assert "authorization" in payload["request"]["headers"]
|
||||
assert "param1" in payload["request"]["params"]
|
||||
assert "param2" in payload["request"]["params"]
|
Loading…
Add table
Add a link
Reference in a new issue