forked from phoenix-oss/llama-stack-mirror

commit 24d48b3692: Merge branch 'main' into eval_api_final

28 changed files with 329 additions and 110 deletions
.github/workflows/changelog.yml (vendored, new file, +29)

@@ -0,0 +1,29 @@
+name: Update Changelog
+
+on:
+  release:
+    types: [published, unpublished, created, edited, deleted, released]
+
+permissions:
+  contents: read
+
+jobs:
+  generate_changelog:
+    name: Generate changelog
+    permissions:
+      contents: write # for peter-evans/create-pull-request to create branch
+      pull-requests: write # for peter-evans/create-pull-request to create a PR
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: main
+          fetch-depth: 0
+      - run: |
+          python ./scripts/gen-changelog.py
+      - uses: peter-evans/create-pull-request@v7
+        with:
+          title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
+          commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
+          branch: create-pull-request/changelog
+          signoff: true
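For orientation, a hypothetical sketch of the kind of script this workflow invokes. The real `scripts/gen-changelog.py` is not part of this diff, so the GitHub API usage, repository name, and output format below are assumptions for illustration only:

```python
# Hypothetical sketch of a release-driven changelog generator; the actual
# scripts/gen-changelog.py is not shown in this diff, so the API calls,
# repository name, and output format here are assumptions.
import json
import urllib.request

REPO = "meta-llama/llama-stack"  # assumed repository


def fetch_releases() -> list[dict]:
    url = f"https://api.github.com/repos/{REPO}/releases?per_page=100"
    with urllib.request.urlopen(url) as resp:
        return json.load(resp)


def main() -> None:
    lines = ["# Changelog", ""]
    for rel in fetch_releases():
        date = (rel.get("published_at") or "")[:10]
        lines.append(f"## {rel['tag_name']} ({date})")
        lines.append(rel.get("body") or "")
        lines.append("")
    with open("CHANGELOG.md", "w") as f:
        f.write("\n".join(lines))


if __name__ == "__main__":
    main()
```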
README.md (20 changes)

@@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
 | vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
 
-### Installation
-
-You have two ways to install this repository:
-
-* **Install as a package**:
-   You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
-   ```bash
-   pip install llama-stack
-   ```
-
-* **Install from source**:
-   If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
-   Then, run the following commands:
-   ```bash
-   git clone git@github.com:meta-llama/llama-stack.git
-   cd llama-stack
-
-   uv sync
-   uv pip install -e .
-   ```
-
 ### Documentation
docs/_static/llama-stack-spec.html (vendored, 26 changes)

@@ -10940,23 +10940,6 @@
       ],
       "title": "ScoreBatchResponse"
     },
-    "AlgorithmConfig": {
-      "oneOf": [
-        {
-          "$ref": "#/components/schemas/LoraFinetuningConfig"
-        },
-        {
-          "$ref": "#/components/schemas/QATFinetuningConfig"
-        }
-      ],
-      "discriminator": {
-        "propertyName": "type",
-        "mapping": {
-          "LoRA": "#/components/schemas/LoraFinetuningConfig",
-          "QAT": "#/components/schemas/QATFinetuningConfig"
-        }
-      }
-    },
     "LoraFinetuningConfig": {
       "type": "object",
       "properties": {

@@ -11092,7 +11075,14 @@
         "type": "string"
       },
       "algorithm_config": {
-        "$ref": "#/components/schemas/AlgorithmConfig"
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/LoraFinetuningConfig"
+          },
+          {
+            "$ref": "#/components/schemas/QATFinetuningConfig"
+          }
+        ]
       }
     },
     "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored, 13 changes)

@@ -7500,15 +7500,6 @@ components:
     required:
       - results
     title: ScoreBatchResponse
-  AlgorithmConfig:
-    oneOf:
-      - $ref: '#/components/schemas/LoraFinetuningConfig'
-      - $ref: '#/components/schemas/QATFinetuningConfig'
-    discriminator:
-      propertyName: type
-      mapping:
-        LoRA: '#/components/schemas/LoraFinetuningConfig'
-        QAT: '#/components/schemas/QATFinetuningConfig'
   LoraFinetuningConfig:
     type: object
     properties:

@@ -7592,7 +7583,9 @@ components:
       checkpoint_dir:
         type: string
       algorithm_config:
-        $ref: '#/components/schemas/AlgorithmConfig'
+        oneOf:
+          - $ref: '#/components/schemas/LoraFinetuningConfig'
+          - $ref: '#/components/schemas/QATFinetuningConfig'
     additionalProperties: false
     required:
       - job_uuid
@@ -6,7 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

@@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
 
 
 AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
+    Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
     name="AlgorithmConfig",
 )
 

@@ -184,7 +184,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
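The hunks above replace the shared `AlgorithmConfig` schema with the union spelled out at each call site, while the pydantic discriminated union stays behind `register_schema`. A minimal sketch of how such a discriminated union parses; the `"LoRA"`/`"QAT"` literals come from the removed OpenAPI mapping, and the extra fields are illustrative, not from the diff:

```python
# Minimal sketch of the discriminated-union pattern used above. The real
# models live in llama_stack.apis.post_training; the rank/quantizer fields
# here are stand-ins.
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8  # illustrative field


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str = "int4"  # illustrative field


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig],
    Field(discriminator="type"),
]


class Job(BaseModel):
    algorithm_config: AlgorithmConfig


# The discriminator routes parsing by the "type" key:
job = Job.model_validate({"algorithm_config": {"type": "LoRA", "rank": 16}})
assert isinstance(job.algorithm_config, LoraFinetuningConfig)
```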
@@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
     )
 
 
+class AuthenticationConfig(BaseModel):
+    endpoint: str = Field(
+        ...,
+        description="Endpoint URL to validate authentication tokens",
+    )
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,

@@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
         default=None,
         description="Path to TLS key file for HTTPS",
     )
+    auth: Optional[AuthenticationConfig] = Field(
+        default=None,
+        description="Authentication configuration for the server",
+    )
 
 
 class StackRunConfig(BaseModel):
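A small sketch of what these new config classes accept. The class bodies mirror the hunks above; the endpoint URL is a placeholder, not a value from the commit:

```python
# Sketch: the new auth config in use.
from typing import Optional

from pydantic import BaseModel, Field


class AuthenticationConfig(BaseModel):
    endpoint: str = Field(
        ...,
        description="Endpoint URL to validate authentication tokens",
    )


class ServerConfig(BaseModel):
    port: int = Field(default=8321)
    auth: Optional[AuthenticationConfig] = Field(default=None)


config = ServerConfig(
    port=8321,
    auth=AuthenticationConfig(endpoint="http://localhost:8080/validate"),  # assumed URL
)
assert config.auth is not None and config.auth.endpoint.endswith("/validate")
```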
llama_stack/distribution/server/auth.py (new file, +69)

@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from urllib.parse import parse_qs
+
+import httpx
+
+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="auth")
+
+
+class AuthenticationMiddleware:
+    def __init__(self, app, auth_endpoint):
+        self.app = app
+        self.auth_endpoint = auth_endpoint
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            headers = dict(scope.get("headers", []))
+            auth_header = headers.get(b"authorization", b"").decode()
+
+            if not auth_header or not auth_header.startswith("Bearer "):
+                return await self._send_auth_error(send, "Missing or invalid Authorization header")
+
+            api_key = auth_header.split("Bearer ", 1)[1]
+
+            path = scope.get("path", "")
+            request_headers = {k.decode(): v.decode() for k, v in headers.items()}
+
+            query_string = scope.get("query_string", b"").decode()
+            params = parse_qs(query_string)
+
+            auth_data = {
+                "api_key": api_key,
+                "request": {
+                    "path": path,
+                    "headers": request_headers,
+                    "params": params,
+                },
+            }
+
+            # Validate with authentication endpoint
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(self.auth_endpoint, json=auth_data)
+                    if response.status_code != 200:
+                        logger.warning(f"Authentication failed: {response.status_code}")
+                        return await self._send_auth_error(send, "Authentication failed")
+            except Exception:
+                logger.exception("Error during authentication")
+                return await self._send_auth_error(send, "Authentication service error")
+
+        return await self.app(scope, receive, send)
+
+    async def _send_auth_error(self, send, message):
+        await send(
+            {
+                "type": "http.response.start",
+                "status": 401,
+                "headers": [[b"content-type", b"application/json"]],
+            }
+        )
+        error_msg = json.dumps({"error": {"message": message}}).encode()
+        await send({"type": "http.response.body", "body": error_msg})
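For reference, a hypothetical sketch of the validation service this middleware POSTs to. The request body mirrors `auth_data` above and a 200 status means success; the FastAPI framework choice and the hard-coded key set are assumptions:

```python
# Hypothetical token-validation endpoint for AuthenticationMiddleware.
# The payload shape matches auth_data above; the key store is a stand-in.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

VALID_KEYS = {"valid_api_key_12345"}  # stand-in for a real key store


class RequestInfo(BaseModel):
    path: str
    headers: dict
    params: dict


class AuthRequest(BaseModel):
    api_key: str
    request: RequestInfo


@app.post("/validate")
def validate(body: AuthRequest):
    # Any non-200 response makes the middleware reject the request.
    if body.api_key not in VALID_KEYS:
        raise HTTPException(status_code=401, detail="invalid key")
    return {"ok": True}
```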
@@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )
 
+from .auth import AuthenticationMiddleware
 from .endpoints import get_all_api_endpoints
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent

@@ -351,6 +352,11 @@ def main():
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
+    # Add authentication middleware if configured
+    if config.server.auth and config.server.auth.endpoint:
+        logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
+        app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
+
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
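With the middleware registered, every HTTP request to the server needs a bearer token. A hedged client-side sketch; the address, route, and key are placeholders, not values from the commit:

```python
# Sketch: calling a server that has auth enabled.
import httpx

resp = httpx.get(
    "http://localhost:8321/v1/models",  # assumed address and route
    headers={"Authorization": "Bearer valid_api_key_12345"},
)
print(resp.status_code)  # 401 unless the auth endpoint accepts the key
```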
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def datasets():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def benchmarks():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def models():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def providers():

@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from page.distribution.benchmarks import benchmarks
-from page.distribution.datasets import datasets
-from page.distribution.models import models
-from page.distribution.scoring_functions import scoring_functions
-from page.distribution.shields import shields
-from page.distribution.vector_dbs import vector_dbs
 from streamlit_option_menu import option_menu
 
+from llama_stack.distribution.ui.page.distribution.datasets import datasets
+from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
+from llama_stack.distribution.ui.page.distribution.models import models
+from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
+from llama_stack.distribution.ui.page.distribution.shields import shields
+from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
+
 
 def resources_page():
     options = [

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def scoring_functions():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def shields():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def vector_dbs():

@@ -8,8 +8,9 @@ import json
 
 import pandas as pd
 import streamlit as st
-from modules.api import llama_stack_api
-from modules.utils import process_dataset
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
+from llama_stack.distribution.ui.modules.utils import process_dataset
 
 
 def application_evaluation_page():

@@ -8,7 +8,8 @@ import json
 
 import pandas as pd
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
 def select_benchmark_1():

@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from modules.api import llama_stack_api
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 # Sidebar configurations
 with st.sidebar:

@@ -7,9 +7,10 @@
 import streamlit as st
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.memory_insert_params import Document
-from modules.api import llama_stack_api
-from modules.utils import data_url_from_file
+from llama_stack_client.types.shared.document import Document
+
+from llama_stack.distribution.ui.modules.api import llama_stack_api
+from llama_stack.distribution.ui.modules.utils import data_url_from_file
 
 
 def rag_chat_page():
@@ -10,6 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import copy
 import json
 import logging
 import multiprocessing

@@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
 
 def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    return ProcessingMessageWrapper(**data).payload
+    return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
 
 
 def worker_process_entrypoint(
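The `copy.deepcopy` above defensively decouples the returned payload from any shared nested state. A small sketch of the aliasing behavior a deep copy guards against in general, using a stand-in wrapper model (names are illustrative, not the real `ProcessingMessageWrapper`):

```python
# Sketch of the aliasing hazard deepcopy avoids.
import copy

from pydantic import BaseModel


class Payload(BaseModel):
    tokens: list[int]


class Wrapper(BaseModel):
    payload: Payload


wrapper = Wrapper(payload=Payload(tokens=[1, 2]))

aliased = wrapper.payload                    # shares the inner list
defensive = copy.deepcopy(wrapper.payload)   # independent copy

wrapper.payload.tokens.append(3)
assert aliased.tokens == [1, 2, 3]   # mutated through the alias
assert defensive.tokens == [1, 2]    # unaffected
```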
@@ -9,6 +9,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Any
+
 from llama_stack.apis.common.type_system import (
     ChatCompletionInputType,
     DialogType,

@@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
     validate_dataset_schema,
 )
 
-EXPECTED_DATASET_SCHEMA = {
+EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
     "instruct": [
         {
             ColumnName.chat_completion_input.value: ChatCompletionInputType(),

@@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
     dataset_type: str,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def:
+        raise ValueError(f"Dataset {dataset_id} does not exist.")
+
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
         raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
@@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
         checkpoint_files: List[str],
         output_dir: str,
         model_type: str,
-    ) -> None:
+    ):
         # Fail fast if ``checkpoint_files`` is invalid
         # TODO: support loading more than one file
         if len(checkpoint_files) != 1:

@@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
         """
         Load Meta checkpoint from file. Currently only loading from a single file is supported.
         """
-        state_dict: Dict[str:Any] = {}
+        state_dict: Dict[str, Any] = {}
         model_state_dict = safe_torch_load(self._checkpoint_path)
         if self._model_type == ModelType.LLAMA3_VISION:
             from torchtune.models.llama3_2_vision._convert_weights import (

@@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
-        checkpoint_format: str = "meta",
+        checkpoint_format: str | None = None,
     ) -> str:
         model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
-        if checkpoint_format == "meta":
+        if checkpoint_format == "meta" or checkpoint_format is None:
             self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
         elif checkpoint_format == "huggingface":
             # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
@@ -10,7 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Callable, Dict
+from typing import Callable, Dict
 
 import torch
 from pydantic import BaseModel

@@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
 from llama_stack.models.llama.datatypes import Model
 from llama_stack.models.llama.sku_list import resolve_model
 
+BuildLoraModelCallable = Callable[..., torch.nn.Module]
+BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
+
 
 class ModelConfig(BaseModel):
-    model_definition: Any
-    tokenizer_type: Any
+    model_definition: BuildLoraModelCallable
+    tokenizer_type: BuildTokenizerCallable
     checkpoint_type: str
 
 

@@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
 }
 
 
-BuildLoraModelCallable = Callable[..., torch.nn.Module]
-BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
-
-
 def _validate_model_id(model_id: str) -> Model:
     model = resolve_model(model_id)
     if model is None or model.core_model_id.value not in MODEL_CONFIGS:
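A short sketch of what the tightened `ModelConfig` typing buys: pydantic now rejects non-callables that `Any` would have let through. The field values below are stand-ins, and `protected_namespaces=()` is assumed so pydantic accepts the `model_` field prefix:

```python
# Sketch: typed callables instead of Any. The builder lambdas are stand-ins
# for the real torchtune model/tokenizer builders.
from typing import Callable

from pydantic import BaseModel, ConfigDict

BuildTokenizerCallable = Callable[..., str]  # stand-in return type


class ModelConfig(BaseModel):
    model_config = ConfigDict(protected_namespaces=())  # allow model_* field names
    model_definition: Callable[..., object]
    tokenizer_type: BuildTokenizerCallable
    checkpoint_type: str


cfg = ModelConfig(
    model_definition=lambda: object(),
    tokenizer_type=lambda: "tokenizer",
    checkpoint_type="meta",
)

# A non-callable now fails validation instead of slipping through Any:
# ModelConfig(model_definition="oops", tokenizer_type=str, checkpoint_type="meta")
```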
@@ -55,7 +55,7 @@ class SFTDataset(Dataset):
         if "messages" in transformed_sample:
             validate_messages(transformed_sample["messages"])
 
-        tokenized_dict = self._model_transform(transformed_sample)
+        tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
 
         if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
             keys_str = ", ".join(tokenized_dict.keys())
@@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
-    AlgorithmConfig,
     Checkpoint,
     LoraFinetuningConfig,
     OptimizerConfig,
+    QATFinetuningConfig,
     TrainingConfig,
 )
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR

@@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
 
     # Currently logging only logs limited training metrics to local disk
    # will figure out more loggings and how it works with telemetry in future PRs
+
+    _checkpointer: TorchtuneCheckpointer
+
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,

@@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig],
+        algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:

@@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
             return str(checkpoint_dir)
 
         if checkpoint_dir and checkpoint_dir != "null":
-            self.checkpoint_dir = config.checkpoint_dir
+            self.checkpoint_dir = checkpoint_dir
         else:
-            model = resolve_model(self.model_id)
-            if model is None:
+            model_obj = resolve_model(self.model_id)
+            if model_obj is None:
                 raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
-            self.checkpoint_dir = model_checkpoint_dir(model)
+            self.checkpoint_dir = model_checkpoint_dir(model_obj)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
         self._checkpoint_format = config.checkpoint_format

@@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
         self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
-        self._enable_activation_checkpointing = (
-            (training_config.efficiency_config.enable_activation_checkpointing)
-            if training_config.efficiency_config
-            else False
-        )
-        self._enable_activation_offloading = (
-            (training_config.efficiency_config.enable_activation_offloading)
-            if training_config.efficiency_config
-            else False
-        )
+        self._enable_activation_checkpointing = False
+        self._enable_activation_offloading = False
+        if training_config.efficiency_config:
+            if training_config.efficiency_config.enable_activation_checkpointing:
+                self._enable_activation_checkpointing = (
+                    training_config.efficiency_config.enable_activation_checkpointing
+                )
+            if training_config.efficiency_config.enable_activation_offloading:
+                self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
 
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api

@@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
         """
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
-        running_loss = 0
+        running_loss: float = 0.0
         num_tokens = 0
 
         # training artifacts
         checkpoints = []
-        memory_stats = {}
+        memory_stats: Dict[str, Any] = {}
 
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):

@@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
                 current_loss = await self._loss_step(batch) * current_num_tokens
-                running_loss += current_loss
+                running_loss += current_loss.detach().item()
                 current_loss.backward()
 
                 # Step with optimizer

@@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
                 # Update the number of steps when the weights are updated
                 self.global_step += 1
 
-                loss_to_log = running_loss.item() / num_tokens
+                loss_to_log = running_loss / num_tokens
 
                 pbar.update(1)
                 pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")

@@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
                 )
 
                 # Reset running stats for the next step
-                running_loss = 0
+                running_loss = 0.0
                 num_tokens = 0
                 t0 = time.perf_counter()
 
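The `running_loss` hunks above switch from accumulating live loss tensors (which keeps autograd graphs alive and forces the later `.item()` call) to accumulating detached Python floats. A hedged sketch of the pattern, with a toy model and data standing in for the torchtune recipe:

```python
# Sketch of the accumulation pattern used above: detach each step's loss
# to a float so no autograd graph is retained across steps.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

running_loss: float = 0.0
num_tokens = 0

for _ in range(4):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    current_num_tokens = x.numel()

    # loss scaled by token count, as in the recipe
    current_loss = torch.nn.functional.mse_loss(model(x), y) * current_num_tokens
    running_loss += current_loss.detach().item()  # a float, not a tensor
    num_tokens += current_num_tokens

    current_loss.backward()
    opt.step()
    opt.zero_grad()

loss_to_log = running_loss / num_tokens  # plain division, no .item() needed
print(f"avg loss per token: {loss_to_log:.4f}")
```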
pyproject.toml

@@ -228,10 +228,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
-    "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
     "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
tests/unit/server/test_auth.py (new file, +124)

@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from llama_stack.distribution.server.auth import AuthenticationMiddleware
+
+
+@pytest.fixture
+def mock_auth_endpoint():
+    return "http://mock-auth-service/validate"
+
+
+@pytest.fixture
+def valid_api_key():
+    return "valid_api_key_12345"
+
+
+@pytest.fixture
+def invalid_api_key():
+    return "invalid_api_key_67890"
+
+
+@pytest.fixture
+def app(mock_auth_endpoint):
+    app = FastAPI()
+    app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
+
+    @app.get("/test")
+    def test_endpoint():
+        return {"message": "Authentication successful"}
+
+    return app
+
+
+@pytest.fixture
+def client(app):
+    return TestClient(app)
+
+
+async def mock_post_success(*args, **kwargs):
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    return mock_response
+
+
+async def mock_post_failure(*args, **kwargs):
+    mock_response = AsyncMock()
+    mock_response.status_code = 401
+    return mock_response
+
+
+async def mock_post_exception(*args, **kwargs):
+    raise Exception("Connection error")
+
+
+def test_missing_auth_header(client):
+    response = client.get("/test")
+    assert response.status_code == 401
+    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+
+
+def test_invalid_auth_header_format(client):
+    response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
+    assert response.status_code == 401
+    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_success)
+def test_valid_authentication(client, valid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
+    assert response.status_code == 200
+    assert response.json() == {"message": "Authentication successful"}
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_failure)
+def test_invalid_authentication(client, invalid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
+    assert response.status_code == 401
+    assert "Authentication failed" in response.json()["error"]["message"]
+
+
+@patch("httpx.AsyncClient.post", new=mock_post_exception)
+def test_auth_service_error(client, valid_api_key):
+    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
+    assert response.status_code == 401
+    assert "Authentication service error" in response.json()["error"]["message"]
+
+
+def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
+    with patch("httpx.AsyncClient.post") as mock_post:
+        mock_response = AsyncMock()
+        mock_response.status_code = 200
+        mock_post.return_value = mock_response
+
+        client.get(
+            "/test?param1=value1&param2=value2",
+            headers={
+                "Authorization": f"Bearer {valid_api_key}",
+                "User-Agent": "TestClient",
+                "Content-Type": "application/json",
+            },
+        )
+
+        # Check that the auth endpoint was called with the correct payload
+        call_args = mock_post.call_args
+        assert call_args is not None
+
+        url, kwargs = call_args[0][0], call_args[1]
+        assert url == mock_auth_endpoint
+
+        payload = kwargs["json"]
+        assert payload["api_key"] == valid_api_key
+        assert payload["request"]["path"] == "/test"
+        assert "authorization" in payload["request"]["headers"]
+        assert "param1" in payload["request"]["params"]
+        assert "param2" in payload["request"]["params"]