Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: Add nemo customizer (#1448)
# What does this PR do?

This PR adds support for NVIDIA's NeMo Customizer API to the Llama Stack post-training module. The integration enables users to fine-tune models using NVIDIA's cloud-based customization service through a consistent Llama Stack interface.

## Test Plan

Yet to be done. Things pending under this PR:

- [x] Integration of the fine-tuned model (new checkpoint) for inference with the NVIDIA LLM distribution
- [x] Distribution integration of the API
- [x] Add test cases for the customizer (in progress)
- [x] Documentation

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/post_training/test_supervised_fine_tuning.py
=========================================== test session starts ===========================================
platform linux -- Python 3.10.0, pytest-8.3.4, pluggy-1.5.0 -- /home/ubuntu/llama-stack/.venv/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.0', 'Platform': 'Linux-6.8.0-1021-gcp-x86_64-with-glibc2.35', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'nbval': '0.11.0', 'metadata': '3.1.1', 'anyio': '4.8.0', 'html': '4.1.1', 'asyncio': '0.25.3'}}
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: nbval-0.11.0, metadata-3.1.1, anyio-4.8.0, html-4.1.1, asyncio-0.25.3
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None
collected 2 items

tests/client-sdk/post_training/test_supervised_fine_tuning.py::test_post_training_provider_registration[txt=8B] PASSED [ 50%]
tests/client-sdk/post_training/test_supervised_fine_tuning.py::test_list_training_jobs[txt=8B] PASSED [100%]

====================================== 2 passed, 1 warning in 0.10s =======================================
```

cc: @mattf @dglogo @sumitb

Co-authored-by: Ubuntu <ubuntu@llama-stack-customizer-dev-inst-2tx95fyisatvlic4we8hidx5tfj.us-central1-a.c.brevdevprod.internal>
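To re-run the plan above, a Llama Stack server serving the `nvidia` distribution must be listening on the port in `LLAMA_STACK_BASE_URL`. A minimal sketch, assuming the conda image built from the `nvidia` template and placeholder credentials (the exact `llama stack run` arguments may differ per setup):

```bash
# Build the nvidia distribution (same command as in the provider README below).
llama stack build --template nvidia --image-type conda

# Start the server with placeholder NeMo Customizer settings matching the
# defaults documented in this PR; port 5002 matches LLAMA_STACK_BASE_URL.
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_CUSTOMIZER_URL="http://nemo.test"
llama stack run nvidia --port 5002

# In a second shell, run the post-training client-sdk tests.
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/post_training/test_supervised_fine_tuning.py
```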
Parent: ba14552a32
Commit: 1a73f8305b
26 changed files with 1571 additions and 8 deletions
```diff
@@ -433,6 +433,7 @@
       "zmq"
     ],
     "nvidia": [
+      "aiohttp",
       "aiosqlite",
       "blobfile",
       "chardet",
```
docs/_static/llama-stack-spec.html (vendored, 6 lines changed)
```diff
@@ -7671,7 +7671,8 @@
                 "completed",
                 "in_progress",
                 "failed",
-                "scheduled"
+                "scheduled",
+                "cancelled"
               ],
               "title": "JobStatus"
             },
```
```diff
@@ -8135,7 +8136,8 @@
                 "completed",
                 "in_progress",
                 "failed",
-                "scheduled"
+                "scheduled",
+                "cancelled"
               ],
               "title": "JobStatus"
             }
```
docs/_static/llama-stack-spec.yaml (vendored, 2 lines changed)
```diff
@@ -5306,6 +5306,7 @@ components:
         - in_progress
         - failed
         - scheduled
+        - cancelled
         title: JobStatus
       scheduled_at:
         type: string
```
```diff
@@ -5583,6 +5584,7 @@ components:
         - in_progress
         - failed
         - scheduled
+        - cancelled
         title: JobStatus
       additionalProperties: false
       required:
```
```diff
@@ -9,6 +9,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following providers:
 | datasetio | `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::nvidia` |
+| post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
 | scoring | `inline::basic` |
 | telemetry | `inline::meta-reference` |
```
```diff
@@ -21,6 +22,12 @@ The `llamastack/distribution-nvidia` distribution consists of the following providers:
 The following environment variables can be configured:
 
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
+- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
+- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
+- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
+- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
+- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
 - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
 - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
 - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
```
```diff
@@ -15,6 +15,7 @@ class JobStatus(Enum):
     in_progress = "in_progress"
     failed = "failed"
     scheduled = "scheduled"
+    cancelled = "cancelled"
 
 
 @json_schema_type
```
```diff
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
 
 
 def available_providers() -> List[ProviderSpec]:
```
```diff
@@ -22,4 +22,13 @@ def available_providers() -> List[ProviderSpec]:
             Api.datasets,
         ],
     ),
+    remote_provider_spec(
+        api=Api.post_training,
+        adapter=AdapterSpec(
+            adapter_type="nvidia",
+            pip_packages=["requests", "aiohttp"],
+            module="llama_stack.providers.remote.post_training.nvidia",
+            config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
+        ),
+    ),
 ]
```
```diff
@@ -55,7 +55,7 @@ from .openai_utils import (
     convert_openai_completion_choice,
     convert_openai_completion_stream,
 )
-from .utils import _is_nvidia_hosted, check_health
+from .utils import _is_nvidia_hosted
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -134,7 +134,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         if content_has_media(content):
             raise NotImplementedError("Media is not supported")
 
-        await check_health(self._config)  # this raises errors
+        # ToDo: check health of NeMo endpoints and enable this
+        # removing this health check as NeMo customizer endpoint health check is returning 404
+        # await check_health(self._config)  # this raises errors
 
         provider_model_id = self.get_provider_model_id(model_id)
         request = convert_completion_request(
```
```diff
@@ -236,7 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         if tool_prompt_format:
             warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
 
-        await check_health(self._config)  # this raises errors
+        # await check_health(self._config)  # this raises errors
 
         provider_model_id = self.get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
```
llama_stack/providers/remote/post_training/__init__.py (new file, 5 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
llama_stack/providers/remote/post_training/nvidia/README.md (new file, 138 lines)
# NVIDIA Post-Training Provider for LlamaStack

This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.

## Features

- Supervised fine-tuning of Llama models
- LoRA fine-tuning support
- Job management and status tracking

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to the Hosted NVIDIA NeMo Customizer service
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service

### Setup

Build the NVIDIA environment:

```bash
llama stack build --template nvidia --image-type conda
```

### Basic Usage using the LlamaStack Python Client

### Create Customization Job

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

#### Configure fine-tuning parameters

```python
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
)
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
```

#### Set up LoRA configuration

```python
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
```

#### Configure training data

```python
data_config = TrainingConfigDataConfig(
    dataset_id="your-dataset-id",  # Use client.datasets.list() to see available datasets
    batch_size=16,
)
```

#### Configure optimizer

```python
optimizer_config = TrainingConfigOptimizerConfig(
    lr=0.0001,
)
```

#### Set up training configuration

```python
training_config = TrainingConfig(
    n_epochs=2,
    data_config=data_config,
    optimizer_config=optimizer_config,
)
```

#### Start fine-tuning job

```python
training_job = client.post_training.supervised_fine_tune(
    job_uuid="unique-job-id",
    model="meta-llama/Llama-3.1-8B-Instruct",
    checkpoint_dir="",
    algorithm_config=algorithm_config,
    training_config=training_config,
    logger_config={},
    hyperparam_search_config={},
)
```

### List all jobs

```python
jobs = client.post_training.job.list()
```

### Check job status

```python
job_status = client.post_training.job.status(job_uuid="your-job-id")
```

### Cancel a job

```python
client.post_training.job.cancel(job_uuid="your-job-id")
```

### Inference with the fine-tuned model

```python
response = client.inference.completion(
    content="Complete the sentence using one word: Roses are red, violets are ",
    stream=False,
    model_id="test-example-model@v1",
    sampling_params={
        "max_tokens": 50,
    },
)
print(response.content)
```
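The README steps above can be chained end to end; a minimal polling sketch that waits for the customization job started in the README example before running inference (assuming the status response exposes a string-valued `status` field, as the provider's status mapping suggests):

```python
import time

# A rough polling loop for the job started above; assumes `status` is one of
# the JobStatus string values ("scheduled", "in_progress", "completed", ...).
while True:
    job_status = client.post_training.job.status(job_uuid="unique-job-id")
    if job_status.status not in ("scheduled", "in_progress"):
        break
    time.sleep(30)

print(f"Job finished with status: {job_status.status}")
```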
llama_stack/providers/remote/post_training/nvidia/__init__.py (new file, 23 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import NvidiaPostTrainingConfig


async def get_adapter_impl(
    config: NvidiaPostTrainingConfig,
    _deps,
):
    from .post_training import NvidiaPostTrainingAdapter

    if not isinstance(config, NvidiaPostTrainingConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")

    impl = NvidiaPostTrainingAdapter(config)
    return impl


__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
```
llama_stack/providers/remote/post_training/nvidia/config.py (new file, 113 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

# TODO: add default values for all fields


class NvidiaPostTrainingConfig(BaseModel):
    """Configuration for NVIDIA Post Training implementation."""

    api_key: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
        description="The NVIDIA API key.",
    )

    dataset_namespace: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
        description="The NVIDIA dataset namespace.",
    )

    project_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
        description="The NVIDIA project ID.",
    )

    # ToDO: validate this, add default value
    customizer_url: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
        description="Base URL for the NeMo Customizer API",
    )

    timeout: int = Field(
        default=300,
        description="Timeout for the NVIDIA Post Training API",
    )

    max_retries: int = Field(
        default=3,
        description="Maximum number of retries for the NVIDIA Post Training API",
    )

    # ToDo: validate this
    output_model_dir: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
        description="Directory to save the output model",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "api_key": "${env.NVIDIA_API_KEY:}",
            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
            "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
            "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
        }


class SFTLoRADefaultConfig(BaseModel):
    """NVIDIA-specific training configuration with default values."""

    # ToDo: split into SFT and LoRA configs??

    # General training parameters
    n_epochs: int = 50

    # NeMo customizer specific parameters
    log_every_n_steps: Optional[int] = None
    val_check_interval: float = 0.25
    sequence_packing_enabled: bool = False
    weight_decay: float = 0.01
    lr: float = 0.0001

    # SFT specific parameters
    hidden_dropout: Optional[float] = None
    attention_dropout: Optional[float] = None
    ffn_dropout: Optional[float] = None

    # LoRA default parameters
    lora_adapter_dim: int = 8
    lora_adapter_dropout: Optional[float] = None
    lora_alpha: int = 16

    # Data config
    batch_size: int = 8

    @classmethod
    def sample_config(cls) -> Dict[str, Any]:
        """Return a sample configuration for NVIDIA training."""
        return {
            "n_epochs": 50,
            "log_every_n_steps": 10,
            "val_check_interval": 0.25,
            "sequence_packing_enabled": False,
            "weight_decay": 0.01,
            "hidden_dropout": 0.1,
            "attention_dropout": 0.1,
            "lora_adapter_dim": 8,
            "lora_alpha": 16,
            "data_config": {
                "dataset_id": "default",
                "batch_size": 8,
            },
            "optimizer_config": {
                "lr": 0.0001,
            },
        }
```
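For illustration, a minimal sketch of how this config resolves values from the environment; each `Field` above uses a `default_factory` that reads the environment at construction time, and the values below are placeholders:

```python
import os

from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"

config = NvidiaPostTrainingConfig()
print(config.customizer_url)     # http://nemo.test
print(config.dataset_namespace)  # "default" (env var unset, falls back to the default)
```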
llama_stack/providers/remote/post_training/nvidia/models.py (new file, 24 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

_MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    )
]


def get_model_entries() -> List[ProviderModelEntry]:
    return _MODEL_ENTRIES
```
llama_stack/providers/remote/post_training/nvidia/post_training.py (new file, 439 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional

import aiohttp
from pydantic import BaseModel, ConfigDict

from llama_stack.apis.post_training import (
    AlgorithmConfig,
    DPOAlignmentConfig,
    JobStatus,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from .models import _MODEL_ENTRIES

# Map API status to JobStatus enum
STATUS_MAPPING = {
    "running": "in_progress",
    "completed": "completed",
    "failed": "failed",
    "cancelled": "cancelled",
    "pending": "scheduled",
}


class NvidiaPostTrainingJob(PostTrainingJob):
    """Parse the response from the Customizer API.

    Inherits job_uuid from PostTrainingJob.
    Adds status, created_at, updated_at parameters.
    Passes through all other parameters from the data field in the response.
    """

    model_config = ConfigDict(extra="allow")
    status: JobStatus
    created_at: datetime
    updated_at: datetime


class ListNvidiaPostTrainingJobs(BaseModel):
    data: List[NvidiaPostTrainingJob]


class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
    model_config = ConfigDict(extra="allow")


class NvidiaPostTrainingAdapter(ModelRegistryHelper):
    def __init__(self, config: NvidiaPostTrainingConfig):
        self.config = config
        self.headers = {}
        if config.api_key:
            self.headers["Authorization"] = f"Bearer {config.api_key}"

        self.timeout = aiohttp.ClientTimeout(total=config.timeout)
        # TODO: filter by available models based on /config endpoint
        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
        self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
        self.customizer_url = config.customizer_url

        if not self.customizer_url:
            warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
            self.customizer_url = "http://nemo.test"

    async def _make_request(
        self,
        method: str,
        path: str,
        headers: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Helper method to make HTTP requests to the Customizer API."""
        url = f"{self.customizer_url}{path}"
        request_headers = self.headers.copy()

        if headers:
            request_headers.update(headers)

        # Add content-type header for JSON requests
        if json and "Content-Type" not in request_headers:
            request_headers["Content-Type"] = "application/json"

        for _ in range(self.config.max_retries):
            async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
                if response.status >= 400:
                    error_data = await response.json()
                    raise Exception(f"API request failed: {error_data}")
                return await response.json()

    async def get_training_jobs(
        self,
        page: Optional[int] = 1,
        page_size: Optional[int] = 10,
        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
    ) -> ListNvidiaPostTrainingJobs:
        """Get all customization jobs.

        Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.

        Returns a ListNvidiaPostTrainingJobs object with the following fields:
        - data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects

        ToDo: Support for schema input for filtering.
        """
        params = {"page": page, "page_size": page_size, "sort": sort}

        response = await self._make_request("GET", "/v1/customization/jobs", params=params)

        jobs = []
        for job in response.get("data", []):
            job_id = job.pop("id")
            job_status = job.pop("status", "unknown").lower()
            mapped_status = STATUS_MAPPING.get(job_status, "unknown")

            # Convert string timestamps to datetime objects
            created_at = (
                datetime.fromisoformat(job.pop("created_at"))
                if "created_at" in job
                else datetime.now(tz=timezone.utc)
            )
            updated_at = (
                datetime.fromisoformat(job.pop("updated_at"))
                if "updated_at" in job
                else datetime.now(tz=timezone.utc)
            )

            # Create NvidiaPostTrainingJob instance
            jobs.append(
                NvidiaPostTrainingJob(
                    job_uuid=job_id,
                    status=JobStatus(mapped_status),
                    created_at=created_at,
                    updated_at=updated_at,
                    **job,
                )
            )

        return ListNvidiaPostTrainingJobs(data=jobs)

    async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
        """Get the status of a customization job.

        Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.

        Returns a NvidiaPostTrainingJob object with the following fields:
        - job_uuid: str - Unique identifier for the job
        - status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
        - created_at: datetime - The time when the job was created
        - updated_at: datetime - The last time the job status was updated

        Additional fields that may be included:
        - steps_completed: Optional[int] - Number of training steps completed
        - epochs_completed: Optional[int] - Number of epochs completed
        - percentage_done: Optional[float] - Percentage of training completed (0-100)
        - best_epoch: Optional[int] - The epoch with the best performance
        - train_loss: Optional[float] - Training loss of the best checkpoint
        - val_loss: Optional[float] - Validation loss of the best checkpoint
        - metrics: Optional[Dict] - Additional training metrics
        - status_logs: Optional[List] - Detailed logs of status changes
        """
        response = await self._make_request(
            "GET",
            f"/v1/customization/jobs/{job_uuid}/status",
            params={"job_id": job_uuid},
        )

        api_status = response.pop("status").lower()
        mapped_status = STATUS_MAPPING.get(api_status, "unknown")

        return NvidiaPostTrainingJobStatusResponse(
            status=JobStatus(mapped_status),
            job_uuid=job_uuid,
            started_at=datetime.fromisoformat(response.pop("created_at")),
            updated_at=datetime.fromisoformat(response.pop("updated_at")),
            **response,
        )

    async def cancel_training_job(self, job_uuid: str) -> None:
        await self._make_request(
            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
        )

    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        raise NotImplementedError("Job artifacts are not implemented yet")

    async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        raise NotImplementedError("Job artifacts are not implemented yet")

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: Dict[str, Any],
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig] = None,
        extra_json: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> NvidiaPostTrainingJob:
        """
        Fine-tunes a model on a dataset.
        Currently only supports LoRA fine-tuning for a standalone docker container.

        Assumptions:
        - The NeMo microservice is running and its endpoint is set in config.customizer_url
        - The dataset is registered separately in the NeMo datastore
        - The model checkpoint is downloaded as per NeMo Customizer requirements

        Parameters:
            training_config: TrainingConfig - Configuration for training
            model: str - Model identifier
            algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
            checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
            job_uuid: str - Unique identifier for the job, ignored atm
            hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
            logger_config: Dict[str, Any] - Configuration for logging, ignored atm

        Environment Variables:
            - NVIDIA_API_KEY: str - API key for the NVIDIA API
              Default: None
            - NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
              Default: "default"
            - NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
              Default: "http://nemo.test"
            - NVIDIA_PROJECT_ID: str - ID of the project
              Default: "test-project"
            - NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
              Default: "test-example-model@v1"

        Supported models:
            - meta/llama-3.1-8b-instruct

        Supported algorithm configs:
            - LoRA, SFT

        Supported Parameters:
            - TrainingConfig:
                - n_epochs: int - Number of epochs to train
                  Default: 50
                - data_config: DataConfig - Configuration for the dataset
                - optimizer_config: OptimizerConfig - Configuration for the optimizer
                - dtype: str - Data type for training
                  not supported (users are informed via warnings)
                - efficiency_config: EfficiencyConfig - Configuration for efficiency
                  not supported
                - max_steps_per_epoch: int - Maximum number of steps per epoch
                  Default: 1000
                ## NeMo customizer specific parameters
                - log_every_n_steps: int - Log every n steps
                  Default: None
                - val_check_interval: float - Validation check interval
                  Default: 0.25
                - sequence_packing_enabled: bool - Sequence packing enabled
                  Default: False
                ## NeMo customizer specific SFT parameters
                - hidden_dropout: float - Hidden dropout
                  Default: None (0.0-1.0)
                - attention_dropout: float - Attention dropout
                  Default: None (0.0-1.0)
                - ffn_dropout: float - FFN dropout
                  Default: None (0.0-1.0)

            - DataConfig:
                - dataset_id: str - Dataset ID
                - batch_size: int - Batch size
                  Default: 8

            - OptimizerConfig:
                - lr: float - Learning rate
                  Default: 0.0001
                ## NeMo customizer specific parameter
                - weight_decay: float - Weight decay
                  Default: 0.01

            - LoRA config:
                ## NeMo customizer specific LoRA parameters
                - adapter_dim: int - Adapter dimension
                  Default: 8 (supports powers of 2)
                - adapter_dropout: float - Adapter dropout
                  Default: None (0.0-1.0)
                - alpha: int - Scaling factor for the LoRA update
                  Default: 16

        Note:
            - checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
            - Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)

        User is informed about unsupported parameters via warnings.
        """
        # Map model to nvidia model name
        # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
        nvidia_model = self.get_provider_model_id(model)

        # Check for unsupported method parameters
        unsupported_method_params = []
        if checkpoint_dir:
            unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
        if hyperparam_search_config:
            unsupported_method_params.append("hyperparam_search_config")
        if logger_config:
            unsupported_method_params.append("logger_config")

        if unsupported_method_params:
            warnings.warn(
                f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
                stacklevel=2,
            )

        # Define all supported parameters
        supported_params = {
            "training_config": {
                "n_epochs",
                "data_config",
                "optimizer_config",
                "log_every_n_steps",
                "val_check_interval",
                "sequence_packing_enabled",
                "hidden_dropout",
                "attention_dropout",
                "ffn_dropout",
            },
            "data_config": {"dataset_id", "batch_size"},
            "optimizer_config": {"lr", "weight_decay"},
            "lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
        }

        # Validate all parameters at once
        warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
        warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
        warn_unsupported_params(
            training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
        )

        output_model = self.config.output_model_dir

        # Prepare base job configuration
        job_config = {
            "config": nvidia_model,
            "dataset": {
                "name": training_config["data_config"]["dataset_id"],
                "namespace": self.config.dataset_namespace,
            },
            "hyperparameters": {
                "training_type": "sft",
                "finetuning_type": "lora",
                **{
                    k: v
                    for k, v in {
                        "epochs": training_config.get("n_epochs"),
                        "batch_size": training_config["data_config"].get("batch_size"),
                        "learning_rate": training_config["optimizer_config"].get("lr"),
                        "weight_decay": training_config["optimizer_config"].get("weight_decay"),
                        "log_every_n_steps": training_config.get("log_every_n_steps"),
                        "val_check_interval": training_config.get("val_check_interval"),
                        "sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
                    }.items()
                    if v is not None
                },
            },
            "project": self.config.project_id,
            # TODO: ignored ownership, add it later
            # "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
            "output_model": output_model,
        }

        # Handle SFT-specific optional parameters
        job_config["hyperparameters"]["sft"] = {
            k: v
            for k, v in {
                "ffn_dropout": training_config.get("ffn_dropout"),
                "hidden_dropout": training_config.get("hidden_dropout"),
                "attention_dropout": training_config.get("attention_dropout"),
            }.items()
            if v is not None
        }

        # Remove the sft dictionary if it's empty
        if not job_config["hyperparameters"]["sft"]:
            job_config["hyperparameters"].pop("sft")

        # Handle LoRA-specific configuration
        if algorithm_config:
            if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
                job_config["hyperparameters"]["lora"] = {
                    k: v
                    for k, v in {
                        "adapter_dim": algorithm_config.get("adapter_dim"),
                        "alpha": algorithm_config.get("alpha"),
                        "adapter_dropout": algorithm_config.get("adapter_dropout"),
                    }.items()
                    if v is not None
                }
            else:
                raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")

        # Create the customization job
        response = await self._make_request(
            method="POST",
            path="/v1/customization/jobs",
            headers={"Accept": "application/json"},
            json=job_config,
        )

        job_uuid = response["id"]
        response.pop("status")
        created_at = datetime.fromisoformat(response.pop("created_at"))
        updated_at = datetime.fromisoformat(response.pop("updated_at"))

        return NvidiaPostTrainingJob(
            job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
        )

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob:
        """Optimize a model based on preference data."""
        raise NotImplementedError("Preference optimization is not implemented yet")

    async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
        raise NotImplementedError("Job logs are not implemented yet")
```
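For illustration, here is roughly the request body that `supervised_fine_tune` above assembles for `POST /v1/customization/jobs` when given the README example inputs; the dataset name is a placeholder, and `None`-valued hyperparameters are dropped by the dict comprehensions:

```python
# Hypothetical payload for the README inputs (n_epochs=2, batch_size=16,
# lr=0.0001, LoRA adapter_dim=16); field mapping follows the code above.
job_config = {
    "config": "meta/llama-3.1-8b-instruct",  # provider model id resolved from meta-llama/Llama-3.1-8B-Instruct
    "dataset": {"name": "your-dataset-id", "namespace": "default"},
    "hyperparameters": {
        "training_type": "sft",
        "finetuning_type": "lora",
        "epochs": 2,
        "batch_size": 16,
        "learning_rate": 0.0001,
        "lora": {"adapter_dim": 16},
    },
    "project": "test-project",
    "output_model": "test-example-model@v1",
}
```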
llama_stack/providers/remote/post_training/nvidia/utils.py (new file, 63 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import warnings
from typing import Any, Dict, Set, Tuple

from pydantic import BaseModel

from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig

from .config import NvidiaPostTrainingConfig

logger = logging.getLogger(__name__)


def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
    keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
    unsupported_params = [k for k in keys if k not in supported_keys]
    if unsupported_params:
        warnings.warn(
            f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
        )


def validate_training_params(
    training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
) -> None:
    """
    Validates training parameters against supported keys.

    Args:
        training_config: Dictionary containing training configuration parameters
        supported_keys: Set of supported parameter keys
        config_name: Name of the configuration for warning messages
    """
    sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
    training_config_fields = set(TrainingConfig.__annotations__.keys())

    # Flag parameters that are not supported:
    # - not in supported_keys or SFTLoRADefaultConfig
    # - but declared in TrainingConfig (i.e., valid API fields this provider ignores)
    unsupported_params = []
    for key in training_config:
        if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
            if key in training_config_fields:
                unsupported_params.append(key)

    if unsupported_params:
        warnings.warn(
            f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
        )


# ToDo: implement these once health checks are enabled for the Customizer endpoints
async def _get_health(url: str) -> Tuple[bool, bool]: ...


async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
```
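A quick illustration of the warning behavior above (hypothetical input; only `n_epochs` is declared supported here):

```python
import warnings

from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params

with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    # "max_steps_per_epoch" is not in the supported set, so it is flagged.
    warn_unsupported_params({"n_epochs": 2, "max_steps_per_epoch": 100}, {"n_epochs"}, "TrainingConfig")

print(captured[0].message)
# Parameters: ['max_steps_per_epoch'] in `TrainingConfig` not supported and will be ignored.
```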
```diff
@@ -14,6 +14,8 @@ distribution_spec:
     - inline::meta-reference
     eval:
     - inline::meta-reference
+    post_training:
+    - remote::nvidia
     datasetio:
     - inline::localfs
     scoring:
```
```diff
@@ -21,6 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
+        "post_training": ["remote::nvidia"],
         "datasetio": ["inline::localfs"],
         "scoring": ["inline::basic"],
         "tool_runtime": ["inline::rag-runtime"],
```
```diff
@@ -89,6 +90,31 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "NVIDIA API Key",
         ),
+        ## Nemo Customizer related variables
+        "NVIDIA_USER_ID": (
+            "llama-stack-user",
+            "NVIDIA User ID",
+        ),
+        "NVIDIA_DATASET_NAMESPACE": (
+            "default",
+            "NVIDIA Dataset Namespace",
+        ),
+        "NVIDIA_ACCESS_POLICIES": (
+            "{}",
+            "NVIDIA Access Policies",
+        ),
+        "NVIDIA_PROJECT_ID": (
+            "test-project",
+            "NVIDIA Project ID",
+        ),
+        "NVIDIA_CUSTOMIZER_URL": (
+            "https://customizer.api.nvidia.com",
+            "NVIDIA Customizer URL",
+        ),
+        "NVIDIA_OUTPUT_MODEL_DIR": (
+            "test-example-model@v1",
+            "NVIDIA Output Model Directory",
+        ),
         "GUARDRAILS_SERVICE_URL": (
             "http://0.0.0.0:7331",
             "URL for the NeMo Guardrails Service",
```
```diff
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
```
```diff
@@ -58,6 +59,14 @@ providers:
       type: sqlite
       namespace: null
       db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+  post_training:
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      api_key: ${env.NVIDIA_API_KEY:}
+      dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
+      project_id: ${env.NVIDIA_PROJECT_ID:test-project}
+      customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
   datasetio:
   - provider_id: localfs
     provider_type: inline::localfs
```
```diff
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
```
```diff
@@ -53,6 +54,14 @@ providers:
       type: sqlite
       namespace: null
       db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+  post_training:
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      api_key: ${env.NVIDIA_API_KEY:}
+      dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
+      project_id: ${env.NVIDIA_PROJECT_ID:test-project}
+      customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
   datasetio:
   - provider_id: localfs
     provider_type: inline::localfs
```
```diff
@@ -56,7 +56,7 @@ dev = [
     "ruamel.yaml", # needed for openapi generator
 ]
 # These are the dependencies required for running unit tests.
-unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet", "qdrant-client"]
+unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
 # These are the core dependencies required for running integration tests. They are shared across all
 # providers. If a provider requires additional dependencies, please add them to your environment
 # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
```
|
@ -64,6 +64,7 @@ unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet", "qdrant-client"
|
||||||
test = [
|
test = [
|
||||||
"openai",
|
"openai",
|
||||||
"aiosqlite",
|
"aiosqlite",
|
||||||
|
"aiohttp",
|
||||||
"torch>=2.6.0",
|
"torch>=2.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"opentelemetry-sdk",
|
"opentelemetry-sdk",
|
||||||
|
```diff
@@ -130,7 +131,6 @@ select = [
     "F",  # Pyflakes
     "N",  # Naming
     "W",  # Warnings
-    "I",  # isort
     "DTZ",  # datetime rules
 ]
 ignore = [
```
```diff
@@ -262,6 +262,7 @@ exclude = [
     "^llama_stack/providers/remote/tool_runtime/model_context_protocol/",
     "^llama_stack/providers/remote/tool_runtime/tavily_search/",
     "^llama_stack/providers/remote/tool_runtime/wolfram_alpha/",
+    "^llama_stack/providers/remote/post_training/nvidia/",
     "^llama_stack/providers/remote/vector_io/chroma/",
     "^llama_stack/providers/remote/vector_io/milvus/",
     "^llama_stack/providers/remote/vector_io/pgvector/",
```
tests/client-sdk/post_training/__init__.py (new file, 5 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
tests/client-sdk/post_training/test_supervied_fine_tuning.py (new file, 60 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]


@pytest.mark.integration
@pytest.fixture(scope="session")
def post_training_provider_available(llama_stack_client):
    providers = llama_stack_client.providers.list()
    post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
    return len(post_training_providers) > 0


@pytest.mark.integration
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
    """Check if post_training is in the api list.
    This is a sanity check to ensure the provider is registered."""
    if not post_training_provider_available:
        pytest.skip("post training provider not available")

    providers = llama_stack_client.providers.list()
    post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
    assert len(post_training_providers) > 0


@pytest.mark.integration
def test_get_training_jobs(llama_stack_client, post_training_provider_available):
    """Test listing all training jobs."""
    if not post_training_provider_available:
        pytest.skip("post training provider not available")

    jobs = llama_stack_client.post_training.get_training_jobs()
    assert isinstance(jobs, dict)
    assert "data" in jobs
    assert isinstance(jobs["data"], list)


@pytest.mark.integration
def test_get_training_job_status(llama_stack_client, post_training_provider_available):
    """Test getting status of a specific training job."""
    if not post_training_provider_available:
        pytest.skip("post training provider not available")

    jobs = llama_stack_client.post_training.get_training_jobs()
    if not jobs["data"]:
        pytest.skip("No training jobs available to check status")

    job_uuid = jobs["data"][0]["job_uuid"]
    job_status = llama_stack_client.post_training.get_training_job_status(job_uuid=job_uuid)

    assert job_status is not None
    assert "job_uuid" in job_status
    assert "status" in job_status
    assert job_status["job_uuid"] == job_uuid
```
tests/unit/providers/nvidia/__init__.py (new file, 5 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
tests/unit/providers/nvidia/conftest.py (new file, 45 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()


@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
    with patch("aiohttp.ClientSession", return_value=mock_session):
        yield


@pytest.fixture
def event_loop():
    """Create and provide a new event loop for each test."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()


@pytest.fixture
def run_async():
    """Fixture to run async functions in tests."""

    def _run_async(coro):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    return _run_async
```
tests/unit/providers/nvidia/test_parameters.py (new file, 271 lines; listing truncated below)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
import warnings
from unittest.mock import patch

import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
    TrainingConfigEfficiencyConfig,
)

from llama_stack.providers.remote.post_training.nvidia.post_training import (
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
)


class TestNvidiaParameters(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"

        config = NvidiaPostTrainingConfig(
            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
        )
        self.adapter = NvidiaPostTrainingAdapter(config)

        self.make_request_patcher = patch(
            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
        )
        self.mock_make_request = self.make_request_patcher.start()
        self.mock_make_request.return_value = {
            "id": "job-123",
            "status": "created",
            "created_at": "2025-03-04T13:07:47.543605",
            "updated_at": "2025-03-04T13:07:47.543605",
        }

    def tearDown(self):
        self.make_request_patcher.stop()

    def _assert_request_params(self, expected_json):
        """Helper method to verify parameters in the request JSON."""
        call_args = self.mock_make_request.call_args
        actual_json = call_args[1]["json"]

        for key, value in expected_json.items():
            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert actual_json[key][nested_key] == nested_value
            else:
                assert actual_json[key] == value

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def test_customizer_parameters_passed(self):
        """Test scenario 1: When an optional parameter is passed and value is correctly set."""
        custom_adapter_dim = 32  # Different from default of 8
        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=custom_adapter_dim,
            adapter_dropout=0.2,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
            rank=16,
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
        training_config = TrainingConfig(
            n_epochs=3,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="test-job",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

        warning_texts = [str(warning.message) for warning in w]

        fields = [
            "apply_lora_to_output",
            "lora_attn_modules",
            "apply_lora_to_mlp",
        ]
        for field in fields:
            assert any(field in text for text in warning_texts)

        self._assert_request_params(
            {
                "hyperparameters": {
```
|
||||||
|
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
|
||||||
|
"epochs": 3,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"batch_size": 16,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_required_parameters_passed(self):
|
||||||
|
"""Test scenario 2: When required parameters are passed."""
|
||||||
|
required_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
required_dataset_id = "required-dataset"
|
||||||
|
required_job_uuid = "required-job"
|
||||||
|
|
||||||
|
algorithm_config = LoraFinetuningConfig(
|
||||||
|
type="LoRA",
|
||||||
|
adapter_dim=16,
|
||||||
|
adapter_dropout=0.1,
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=True,
|
||||||
|
alpha=16,
|
||||||
|
rank=16,
|
||||||
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
|
)
|
||||||
|
|
||||||
|
data_config = TrainingConfigDataConfig(
|
||||||
|
dataset_id=required_dataset_id, # Required parameter
|
||||||
|
batch_size=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
|
||||||
|
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=1,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
|
self.run_async(
|
||||||
|
self.adapter.supervised_fine_tune(
|
||||||
|
job_uuid=required_job_uuid, # Required parameter
|
||||||
|
model=required_model, # Required parameter
|
||||||
|
checkpoint_dir="",
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
training_config=training_config,
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
warning_texts = [str(warning.message) for warning in w]
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
"rank",
|
||||||
|
"apply_lora_to_output",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"apply_lora_to_mlp",
|
||||||
|
]
|
||||||
|
for field in fields:
|
||||||
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
|
self.mock_make_request.assert_called_once()
|
||||||
|
call_args = self.mock_make_request.call_args
|
||||||
|
|
||||||
|
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
|
||||||
|
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
||||||
|
|
||||||
|
def test_unsupported_parameters_warning(self):
|
||||||
|
"""Test that warnings are raised for unsupported parameters."""
|
||||||
|
data_config = TrainingConfigDataConfig(
|
||||||
|
dataset_id="test-dataset",
|
||||||
|
batch_size=8,
|
||||||
|
# Unsupported parameters
|
||||||
|
shuffle=True,
|
||||||
|
data_format="instruct",
|
||||||
|
validation_dataset_id="val-dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_config = TrainingConfigOptimizerConfig(
|
||||||
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
# Unsupported parameters
|
||||||
|
optimizer_type="adam",
|
||||||
|
num_warmup_steps=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
efficiency_config = TrainingConfigEfficiencyConfig(
|
||||||
|
enable_activation_checkpointing=True # Unsupported parameter
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=1,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
# Unsupported parameters
|
||||||
|
efficiency_config=efficiency_config,
|
||||||
|
max_steps_per_epoch=1000,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
max_validation_steps=100,
|
||||||
|
dtype="bf16",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture warnings
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
|
self.run_async(
|
||||||
|
self.adapter.supervised_fine_tune(
|
||||||
|
job_uuid="test-job",
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
checkpoint_dir="test-dir", # Unsupported parameter
|
||||||
|
algorithm_config=LoraFinetuningConfig(
|
||||||
|
type="LoRA",
|
||||||
|
adapter_dim=16,
|
||||||
|
adapter_dropout=0.1,
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=True,
|
||||||
|
alpha=16,
|
||||||
|
rank=16,
|
||||||
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
|
),
|
||||||
|
training_config=training_config,
|
||||||
|
logger_config={"test": "value"}, # Unsupported parameter
|
||||||
|
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(w) >= 4
|
||||||
|
warning_texts = [str(warning.message) for warning in w]
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
"checkpoint_dir",
|
||||||
|
"hyperparam_search_config",
|
||||||
|
"logger_config",
|
||||||
|
"TrainingConfig",
|
||||||
|
"DataConfig",
|
||||||
|
"OptimizerConfig",
|
||||||
|
"max_steps_per_epoch",
|
||||||
|
"gradient_accumulation_steps",
|
||||||
|
"max_validation_steps",
|
||||||
|
"dtype",
|
||||||
|
# required unsupported parameters
|
||||||
|
"rank",
|
||||||
|
"apply_lora_to_output",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"apply_lora_to_mlp",
|
||||||
|
]
|
||||||
|
for field in fields:
|
||||||
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
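Read together, these tests document how Llama Stack training options map onto the NeMo Customizer request payload. A condensed sketch of that mapping, with illustrative values only (not values asserted anywhere beyond what the tests above check):

```
# Llama Stack-side configuration (illustrative values)...
llama_stack_params = {
    "n_epochs": 3,         # from TrainingConfig
    "batch_size": 16,      # from TrainingConfigDataConfig
    "lr": 0.0002,          # from TrainingConfigOptimizerConfig
    "adapter_dim": 32,     # from LoraFinetuningConfig
    "adapter_dropout": 0.2,
    "alpha": 16,
}

# ...and the NeMo Customizer hyperparameters the adapter is expected to emit,
# per the _assert_request_params checks above.
expected_customizer_json = {
    "hyperparameters": {
        "epochs": 3,
        "batch_size": 16,
        "learning_rate": 0.0002,
        "lora": {"adapter_dim": 32, "adapter_dropout": 0.2, "alpha": 16},
    }
}
```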
295 tests/unit/providers/nvidia/test_supervised_fine_tuning.py Normal file
@@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
from unittest.mock import patch
import warnings

import pytest

from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
)

from llama_stack.providers.remote.post_training.nvidia.post_training import (
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
    NvidiaPostTrainingJobStatusResponse,
    ListNvidiaPostTrainingJobs,
    NvidiaPostTrainingJob,
)


class TestNvidiaPostTraining(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"  # needed for llm inference
        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer

        config = NvidiaPostTrainingConfig(
            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
        )
        self.adapter = NvidiaPostTrainingAdapter(config)
        self.make_request_patcher = patch(
            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
        )
        self.mock_make_request = self.make_request_patcher.start()

    def tearDown(self):
        self.make_request_patcher.stop()

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
        """Helper method to verify request details in mock calls."""
        call_args = mock_call.call_args

        if expected_method and expected_path:
            if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
                assert call_args[0] == (expected_method, expected_path)
            else:
                assert call_args[1]["method"] == expected_method
                assert call_args[1]["path"] == expected_path

        if expected_params:
            assert call_args[1]["params"] == expected_params

        if expected_json:
            for key, value in expected_json.items():
                assert call_args[1]["json"][key] == value

    def test_supervised_fine_tune(self):
        """Test the supervised fine-tuning API call."""
        self.mock_make_request.return_value = {
            "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
            "created_at": "2024-12-09T04:06:28.542884",
            "updated_at": "2024-12-09T04:06:28.542884",
            "config": {
                "schema_version": "1.0",
                "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
                "created_at": "2024-12-09T04:06:28.542657",
                "updated_at": "2024-12-09T04:06:28.569837",
                "custom_fields": {},
                "name": "meta-llama/Llama-3.1-8B-Instruct",
                "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                "model_path": "llama-3_1-8b-instruct",
                "training_types": [],
                "finetuning_types": ["lora"],
                "precision": "bf16",
                "num_gpus": 4,
                "num_nodes": 1,
                "micro_batch_size": 1,
                "tensor_parallel_size": 1,
                "max_seq_length": 4096,
            },
            "dataset": {
                "schema_version": "1.0",
                "id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
                "created_at": "2024-12-09T04:06:28.542657",
                "updated_at": "2024-12-09T04:06:28.542660",
                "custom_fields": {},
                "name": "sample-basic-test",
                "version_id": "main",
                "version_tags": [],
            },
            "hyperparameters": {
                "finetuning_type": "lora",
                "training_type": "sft",
                "batch_size": 16,
                "epochs": 2,
                "learning_rate": 0.0001,
                "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
            },
            "output_model": "default/job-1234",
            "status": "created",
            "project": "default",
            "custom_fields": {},
            "ownership": {"created_by": "me", "access_policies": {}},
        }

        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=16,
            adapter_dropout=0.1,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
            rank=16,
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)

        optimizer_config = TrainingConfigOptimizerConfig(
            lr=0.0001,
        )

        training_config = TrainingConfig(
            n_epochs=2,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )

        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            training_job = self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="1234",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

        # check the output is a PostTrainingJob
        assert isinstance(training_job, NvidiaPostTrainingJob)
        assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            "/v1/customization/jobs",
            expected_json={
                "config": "meta/llama-3.1-8b-instruct",
                "dataset": {"name": "sample-basic-test", "namespace": "default"},
                "hyperparameters": {
                    "training_type": "sft",
                    "finetuning_type": "lora",
                    "epochs": 2,
                    "batch_size": 16,
                    "learning_rate": 0.0001,
                    "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
                },
            },
        )

    def test_supervised_fine_tune_with_qat(self):
        algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(
            lr=0.0001,
        )
        training_config = TrainingConfig(
            n_epochs=2,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )
        # This will raise NotImplementedError since QAT is not supported
        with self.assertRaises(NotImplementedError):
            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="1234",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

    def test_get_training_job_status(self):
        self.mock_make_request.return_value = {
            "created_at": "2024-12-09T04:06:28.580220",
            "updated_at": "2024-12-09T04:21:19.852832",
            "status": "completed",
            "steps_completed": 1210,
            "epochs_completed": 2,
            "percentage_done": 100.0,
            "best_epoch": 2,
            "train_loss": 1.718016266822815,
            "val_loss": 1.8661999702453613,
        }

        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))

        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
        assert status.status.value == "completed"
        assert status.steps_completed == 1210
        assert status.epochs_completed == 2
        assert status.percentage_done == 100.0
        assert status.best_epoch == 2
        assert status.train_loss == 1.718016266822815
        assert status.val_loss == 1.8661999702453613

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
        )

    def test_get_training_jobs(self):
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
        self.mock_make_request.return_value = {
            "data": [
                {
                    "id": job_id,
                    "created_at": "2024-12-09T04:06:28.542884",
                    "updated_at": "2024-12-09T04:21:19.852832",
                    "config": {
                        "name": "meta-llama/Llama-3.1-8B-Instruct",
                        "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                    },
                    "dataset": {"name": "default/sample-basic-test"},
                    "hyperparameters": {
                        "finetuning_type": "lora",
                        "training_type": "sft",
                        "batch_size": 16,
                        "epochs": 2,
                        "learning_rate": 0.0001,
                        "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                    },
                    "output_model": "default/job-1234",
                    "status": "completed",
                    "project": "default",
                }
            ]
        }

        jobs = self.run_async(self.adapter.get_training_jobs())

        assert isinstance(jobs, ListNvidiaPostTrainingJobs)
        assert len(jobs.data) == 1
        job = jobs.data[0]
        assert job.job_uuid == job_id
        assert job.status.value == "completed"

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "GET",
            "/v1/customization/jobs",
            expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
        )

    def test_cancel_training_job(self):
        self.mock_make_request.return_value = {}  # Empty response for successful cancellation
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))

        assert result is None

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            f"/v1/customization/jobs/{job_id}/cancel",
            expected_params={"job_id": job_id},
        )


if __name__ == "__main__":
    unittest.main()
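For orientation, a minimal sketch of exercising the adapter directly with the same calls these tests mock; http://nemo.test is the tests' stand-in endpoint, so a real NeMo Customizer URL would be required for this to return live results:

```
import asyncio

from llama_stack.providers.remote.post_training.nvidia.post_training import (
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
)


async def main():
    # Same configuration shape the tests build in setUp(); swap in a real
    # NeMo Customizer endpoint and API key for actual use.
    config = NvidiaPostTrainingConfig(
        base_url="http://nemo.test", customizer_url="http://nemo.test", api_key=None
    )
    adapter = NvidiaPostTrainingAdapter(config)

    # List jobs, then fetch one job's status -- the calls exercised above.
    jobs = await adapter.get_training_jobs()
    if jobs.data:
        status = await adapter.get_training_job_status(job_uuid=jobs.data[0].job_uuid)
        print(status.status)


asyncio.run(main())
```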
6 uv.lock generated
@@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@@ -1370,6 +1371,7 @@ docs = [
    { name = "tomli" },
]
test = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
    { name = "autoevals" },
    { name = "chardet" },
@@ -1385,6 +1387,7 @@ test = [
    { name = "torchvision", version = "0.21.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
    { name = "chardet" },
    { name = "openai" },
@@ -1395,6 +1398,8 @@ unit = [

[package.metadata]
requires-dist = [
    { name = "aiohttp", marker = "extra == 'test'" },
    { name = "aiohttp", marker = "extra == 'unit'" },
    { name = "aiosqlite", marker = "extra == 'test'" },
    { name = "aiosqlite", marker = "extra == 'unit'" },
    { name = "autoevals", marker = "extra == 'test'" },
@@ -1455,6 +1460,7 @@ requires-dist = [
    { name = "types-setuptools", marker = "extra == 'dev'" },
    { name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "unit", "test", "docs", "codegen"]

[[package]]
name = "llama-stack-client"