mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
Merge branch 'main' into chore/add-upstream-to-sl-config
This commit is contained in:
commit
36c8196d17
22 changed files with 583 additions and 521 deletions
|
|
@ -4,3 +4,9 @@ omit =
|
||||||
*/llama_stack/providers/*
|
*/llama_stack/providers/*
|
||||||
*/llama_stack/templates/*
|
*/llama_stack/templates/*
|
||||||
.venv/*
|
.venv/*
|
||||||
|
*/llama_stack/cli/scripts/*
|
||||||
|
*/llama_stack/ui/*
|
||||||
|
*/llama_stack/distribution/ui/*
|
||||||
|
*/llama_stack/strong_typing/*
|
||||||
|
*/llama_stack/env.py
|
||||||
|
*/__init__.py
|
||||||
|
|
|
||||||
57
.github/workflows/coverage-badge.yml
vendored
Normal file
57
.github/workflows/coverage-badge.yml
vendored
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
name: Coverage Badge
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'llama_stack/**'
|
||||||
|
- 'tests/unit/**'
|
||||||
|
- 'uv.lock'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '.github/workflows/unit-tests.yml'
|
||||||
|
- '.github/workflows/coverage-badge.yml' # This workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
unit-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
run: |
|
||||||
|
./scripts/unit-tests.sh
|
||||||
|
|
||||||
|
- name: Coverage Badge
|
||||||
|
uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4
|
||||||
|
|
||||||
|
- name: Verify Changed files
|
||||||
|
uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4
|
||||||
|
id: verify-changed-files
|
||||||
|
with:
|
||||||
|
files: coverage.svg
|
||||||
|
|
||||||
|
- name: Commit files
|
||||||
|
if: steps.verify-changed-files.outputs.files_changed == 'true'
|
||||||
|
run: |
|
||||||
|
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
git config --local user.name "github-actions[bot]"
|
||||||
|
git add coverage.svg
|
||||||
|
git commit -m "Updated coverage.svg"
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.verify-changed-files.outputs.files_changed == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
title: "ci: [Automatic] Coverage Badge Update"
|
||||||
|
body: |
|
||||||
|
This PR updates the coverage badge based on the latest coverage report.
|
||||||
|
|
||||||
|
Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml)
|
||||||
|
delete-branch: true
|
||||||
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
|
|
@ -7,7 +7,7 @@ on:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- 'llama_stack/**'
|
- 'llama_stack/**'
|
||||||
- 'tests/integration/**'
|
- 'tests/**'
|
||||||
- 'uv.lock'
|
- 'uv.lock'
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
|
|
|
||||||
2
.github/workflows/unit-tests.yml
vendored
2
.github/workflows/unit-tests.yml
vendored
|
|
@ -36,7 +36,7 @@ jobs:
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: |
|
run: |
|
||||||
PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
|
PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml
|
||||||
|
|
||||||
- name: Upload test results
|
- name: Upload test results
|
||||||
if: always()
|
if: always()
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
[](https://discord.gg/llama-stack)
|
[](https://discord.gg/llama-stack)
|
||||||
[](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
|
[](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
|
||||||
[](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
|
[](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
|
||||||
|

|
||||||
|
|
||||||
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
|
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
|
||||||
|
|
||||||
|
|
|
||||||
21
coverage.svg
Normal file
21
coverage.svg
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="99" height="20">
|
||||||
|
<linearGradient id="b" x2="0" y2="100%">
|
||||||
|
<stop offset="0" stop-color="#bbb" stop-opacity=".1"/>
|
||||||
|
<stop offset="1" stop-opacity=".1"/>
|
||||||
|
</linearGradient>
|
||||||
|
<mask id="a">
|
||||||
|
<rect width="99" height="20" rx="3" fill="#fff"/>
|
||||||
|
</mask>
|
||||||
|
<g mask="url(#a)">
|
||||||
|
<path fill="#555" d="M0 0h63v20H0z"/>
|
||||||
|
<path fill="#fe7d37" d="M63 0h36v20H63z"/>
|
||||||
|
<path fill="url(#b)" d="M0 0h99v20H0z"/>
|
||||||
|
</g>
|
||||||
|
<g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
|
||||||
|
<text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
|
||||||
|
<text x="31.5" y="14">coverage</text>
|
||||||
|
<text x="80" y="15" fill="#010101" fill-opacity=".3">44%</text>
|
||||||
|
<text x="80" y="14">44%</text>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 904 B |
|
|
@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`),
|
||||||
|
|
||||||
## Running the Distribution
|
## Running the Distribution
|
||||||
|
|
||||||
You can run the starter distribution via Docker or Conda.
|
You can run the starter distribution via Docker, Conda, or venv.
|
||||||
|
|
||||||
### Via Docker
|
### Via Docker
|
||||||
|
|
||||||
|
|
@ -186,17 +186,12 @@ docker run \
|
||||||
--port $LLAMA_STACK_PORT
|
--port $LLAMA_STACK_PORT
|
||||||
```
|
```
|
||||||
|
|
||||||
### Via Conda
|
### Via Conda or venv
|
||||||
|
|
||||||
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
|
Ensure you have configured the starter distribution using the environment variables explained above.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llama stack build --template starter --image-type conda
|
uv run --with llama-stack llama stack build --template starter --image-type <conda|venv> --run
|
||||||
llama stack run distributions/starter/run.yaml \
|
|
||||||
--port 8321 \
|
|
||||||
--env OPENAI_API_KEY=your_openai_key \
|
|
||||||
--env FIREWORKS_API_KEY=your_fireworks_key \
|
|
||||||
--env TOGETHER_API_KEY=your_together_key
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Example Usage
|
## Example Usage
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
|
||||||
#### Step 2: Run the Llama Stack server
|
#### Step 2: Run the Llama Stack server
|
||||||
We will use `uv` to run the Llama Stack server.
|
We will use `uv` to run the Llama Stack server.
|
||||||
```bash
|
```bash
|
||||||
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
||||||
```
|
```
|
||||||
#### Step 3: Run the demo
|
#### Step 3: Run the demo
|
||||||
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
|
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
|
||||||
|
|
@ -111,6 +111,12 @@ Ultimately, great work is about making a meaningful contribution and leaving a l
|
||||||
```
|
```
|
||||||
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
|
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
|
||||||
|
|
||||||
|
```{admonition} HuggingFace access
|
||||||
|
:class: tip
|
||||||
|
|
||||||
|
If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment
|
||||||
|
```
|
||||||
|
|
||||||
### Next Steps
|
### Next Steps
|
||||||
|
|
||||||
Now you're ready to dive deeper into Llama Stack!
|
Now you're ready to dive deeper into Llama Stack!
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import io
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from PIL import Image as PIL_Image
|
from PIL import Image as PIL_Image
|
||||||
|
|
||||||
|
|
@ -184,16 +185,26 @@ class ChatFormat:
|
||||||
content = content[: -len("<|eom_id|>")]
|
content = content[: -len("<|eom_id|>")]
|
||||||
stop_reason = StopReason.end_of_message
|
stop_reason = StopReason.end_of_message
|
||||||
|
|
||||||
tool_name = None
|
tool_name: str | BuiltinTool | None = None
|
||||||
tool_arguments = {}
|
tool_arguments: dict[str, Any] = {}
|
||||||
|
|
||||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||||
if custom_tool_info is not None:
|
if custom_tool_info is not None:
|
||||||
tool_name, tool_arguments = custom_tool_info
|
# Type guard: ensure custom_tool_info is a tuple of correct types
|
||||||
|
if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
|
||||||
|
extracted_tool_name, extracted_tool_arguments = custom_tool_info
|
||||||
|
# Handle both dict and str return types from the function
|
||||||
|
if isinstance(extracted_tool_arguments, dict):
|
||||||
|
tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
|
||||||
|
else:
|
||||||
|
# If it's a string, treat it as a query parameter
|
||||||
|
tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
|
||||||
|
else:
|
||||||
|
tool_name, tool_arguments = None, {}
|
||||||
# Sometimes when agent has custom tools alongside builin tools
|
# Sometimes when agent has custom tools alongside builin tools
|
||||||
# Agent responds for builtin tool calls in the format of the custom tools
|
# Agent responds for builtin tool calls in the format of the custom tools
|
||||||
# This code tries to handle that case
|
# This code tries to handle that case
|
||||||
if tool_name in BuiltinTool.__members__:
|
if tool_name is not None and tool_name in BuiltinTool.__members__:
|
||||||
tool_name = BuiltinTool[tool_name]
|
tool_name = BuiltinTool[tool_name]
|
||||||
if isinstance(tool_arguments, dict):
|
if isinstance(tool_arguments, dict):
|
||||||
tool_arguments = {
|
tool_arguments = {
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
def mp_rank_0() -> bool:
|
def mp_rank_0() -> bool:
|
||||||
return get_model_parallel_rank() == 0
|
return bool(get_model_parallel_rank() == 0)
|
||||||
|
|
||||||
|
|
||||||
def encode_msg(msg: ProcessingMessage) -> bytes:
|
def encode_msg(msg: ProcessingMessage) -> bytes:
|
||||||
|
|
@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
|
||||||
reply_socket.send_multipart([client_id, encode_msg(obj)])
|
reply_socket.send_multipart([client_id, encode_msg(obj)])
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
tasks = [None]
|
tasks: list[ProcessingMessage | None] = [None]
|
||||||
if mp_rank_0():
|
if mp_rank_0():
|
||||||
client_id, maybe_task_json = maybe_get_work(reply_socket)
|
client_id, maybe_task_json = maybe_get_work(reply_socket)
|
||||||
if maybe_task_json is not None:
|
if maybe_task_json is not None:
|
||||||
|
|
@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
|
||||||
break
|
break
|
||||||
|
|
||||||
for obj in out:
|
for obj in out:
|
||||||
updates = [None]
|
updates: list[ProcessingMessage | None] = [None]
|
||||||
if mp_rank_0():
|
if mp_rank_0():
|
||||||
_, update_json = maybe_get_work(reply_socket)
|
_, update_json = maybe_get_work(reply_socket)
|
||||||
update = maybe_parse_message(update_json)
|
update = maybe_parse_message(update_json)
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,7 @@ unit = [
|
||||||
"pymilvus>=2.5.12",
|
"pymilvus>=2.5.12",
|
||||||
"litellm",
|
"litellm",
|
||||||
"together",
|
"together",
|
||||||
|
"coverage",
|
||||||
]
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
|
|
@ -242,7 +243,6 @@ exclude = [
|
||||||
"^llama_stack/distribution/store/registry\\.py$",
|
"^llama_stack/distribution/store/registry\\.py$",
|
||||||
"^llama_stack/distribution/utils/exec\\.py$",
|
"^llama_stack/distribution/utils/exec\\.py$",
|
||||||
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
|
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
|
||||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||||
|
|
@ -255,7 +255,6 @@ exclude = [
|
||||||
"^llama_stack/models/llama/llama3/generation\\.py$",
|
"^llama_stack/models/llama/llama3/generation\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
||||||
"^llama_stack/models/llama/llama4/",
|
"^llama_stack/models/llama/llama4/",
|
||||||
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
|
|
||||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
||||||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||||
"^llama_stack/providers/inline/inference/vllm/",
|
"^llama_stack/providers/inline/inference/vllm/",
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,9 @@ if [ $FOUND_PYTHON -ne 0 ]; then
|
||||||
uv python install "$PYTHON_VERSION"
|
uv python install "$PYTHON_VERSION"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@
|
# Run unit tests with coverage
|
||||||
|
uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
|
||||||
|
coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
|
||||||
|
|
||||||
|
# Generate HTML coverage report
|
||||||
|
uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
|
||||||
|
|
|
||||||
|
|
@ -123,14 +123,14 @@ class TestPostTraining:
|
||||||
logger.info(f"Job artifacts: {artifacts}")
|
logger.info(f"Job artifacts: {artifacts}")
|
||||||
|
|
||||||
# TODO: Fix these tests to properly represent the Jobs API in training
|
# TODO: Fix these tests to properly represent the Jobs API in training
|
||||||
# @pytest.mark.asyncio
|
#
|
||||||
# async def test_get_training_jobs(self, post_training_stack):
|
# async def test_get_training_jobs(self, post_training_stack):
|
||||||
# post_training_impl = post_training_stack
|
# post_training_impl = post_training_stack
|
||||||
# jobs_list = await post_training_impl.get_training_jobs()
|
# jobs_list = await post_training_impl.get_training_jobs()
|
||||||
# assert isinstance(jobs_list, list)
|
# assert isinstance(jobs_list, list)
|
||||||
# assert jobs_list[0].job_uuid == "1234"
|
# assert jobs_list[0].job_uuid == "1234"
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
#
|
||||||
# async def test_get_training_job_status(self, post_training_stack):
|
# async def test_get_training_job_status(self, post_training_stack):
|
||||||
# post_training_impl = post_training_stack
|
# post_training_impl = post_training_stack
|
||||||
# job_status = await post_training_impl.get_training_job_status("1234")
|
# job_status = await post_training_impl.get_training_job_status("1234")
|
||||||
|
|
@ -139,7 +139,7 @@ class TestPostTraining:
|
||||||
# assert job_status.status == JobStatus.completed
|
# assert job_status.status == JobStatus.completed
|
||||||
# assert isinstance(job_status.checkpoints[0], Checkpoint)
|
# assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
#
|
||||||
# async def test_get_training_job_artifacts(self, post_training_stack):
|
# async def test_get_training_job_artifacts(self, post_training_stack):
|
||||||
# post_training_impl = post_training_stack
|
# post_training_impl = post_training_stack
|
||||||
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,17 @@
|
||||||
# Llama Stack Unit Tests
|
# Llama Stack Unit Tests
|
||||||
|
|
||||||
|
## Unit Tests
|
||||||
|
|
||||||
|
Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services.
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Python Environment**: Ensure you have Python 3.12+ installed
|
||||||
|
2. **uv Package Manager**: Install `uv` if not already installed
|
||||||
|
|
||||||
You can run the unit tests by running:
|
You can run the unit tests by running:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source .venv/bin/activate
|
|
||||||
./scripts/unit-tests.sh [PYTEST_ARGS]
|
./scripts/unit-tests.sh [PYTEST_ARGS]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
|
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Test Configuration
|
||||||
|
|
||||||
|
- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory
|
||||||
|
- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling
|
||||||
|
- **Coverage**: Tests generate coverage reports in `htmlcov/` directory
|
||||||
|
- **Python Version**: Defaults to Python 3.12, but can be overridden with `PYTHON_VERSION` environment variable
|
||||||
|
|
||||||
|
### Coverage Reports
|
||||||
|
|
||||||
|
After running tests, you can view coverage reports:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Open HTML coverage report in browser
|
||||||
|
open htmlcov/index.html # macOS
|
||||||
|
xdg-open htmlcov/index.html # Linux
|
||||||
|
start htmlcov/index.html # Windows
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
|
@ -32,7 +31,6 @@ MODEL = "Llama3.1-8B-Instruct"
|
||||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_system_default():
|
async def test_system_default():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
|
|
@ -47,7 +45,6 @@ async def test_system_default():
|
||||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_system_builtin_only():
|
async def test_system_builtin_only():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
|
|
@ -67,7 +64,6 @@ async def test_system_builtin_only():
|
||||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_system_custom_only():
|
async def test_system_custom_only():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
|
|
@ -98,7 +94,6 @@ async def test_system_custom_only():
|
||||||
assert messages[-1].content == content
|
assert messages[-1].content == content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_system_custom_and_builtin():
|
async def test_system_custom_and_builtin():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
|
|
@ -132,7 +127,6 @@ async def test_system_custom_and_builtin():
|
||||||
assert messages[-1].content == content
|
assert messages[-1].content == content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_completion_message_encoding():
|
async def test_completion_message_encoding():
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=MODEL3_2,
|
model=MODEL3_2,
|
||||||
|
|
@ -174,7 +168,6 @@ async def test_completion_message_encoding():
|
||||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_user_provided_system_message():
|
async def test_user_provided_system_message():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
system_prompt = "You are a pirate"
|
system_prompt = "You are a pirate"
|
||||||
|
|
@ -195,7 +188,6 @@ async def test_user_provided_system_message():
|
||||||
assert messages[-1].content == content
|
assert messages[-1].content == content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_replace_system_message_behavior_builtin_tools():
|
async def test_replace_system_message_behavior_builtin_tools():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
system_prompt = "You are a pirate"
|
system_prompt = "You are a pirate"
|
||||||
|
|
@ -221,7 +213,6 @@ async def test_replace_system_message_behavior_builtin_tools():
|
||||||
assert messages[-1].content == content
|
assert messages[-1].content == content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_replace_system_message_behavior_custom_tools():
|
async def test_replace_system_message_behavior_custom_tools():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
system_prompt = "You are a pirate"
|
system_prompt = "You are a pirate"
|
||||||
|
|
@ -259,7 +250,6 @@ async def test_replace_system_message_behavior_custom_tools():
|
||||||
assert messages[-1].content == content
|
assert messages[-1].content == content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_replace_system_message_behavior_custom_tools_with_template():
|
async def test_replace_system_message_behavior_custom_tools_with_template():
|
||||||
content = "Hello !"
|
content = "Hello !"
|
||||||
system_prompt = "You are a pirate {{ function_description }}"
|
system_prompt = "You are a pirate {{ function_description }}"
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@
|
||||||
# the top-level of this source tree.
|
# the top-level of this source tree.
|
||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
import unittest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||||
|
|
@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateTests(unittest.TestCase):
|
def check_generator_output(generator):
|
||||||
def check_generator_output(self, generator):
|
for example in generator.data_examples():
|
||||||
for example in generator.data_examples():
|
pt = generator.gen(example)
|
||||||
pt = generator.gen(example)
|
|
||||||
text = pt.render()
|
|
||||||
# print(text) # debugging
|
|
||||||
if not example:
|
|
||||||
continue
|
|
||||||
for tool in example:
|
|
||||||
assert tool.tool_name in text
|
|
||||||
|
|
||||||
def test_system_default(self):
|
|
||||||
generator = SystemDefaultGenerator()
|
|
||||||
today = datetime.now().strftime("%d %B %Y")
|
|
||||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
|
||||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
|
||||||
|
|
||||||
def test_system_builtin_only(self):
|
|
||||||
generator = BuiltinToolGenerator()
|
|
||||||
expected_text = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Environment: ipython
|
|
||||||
Tools: brave_search, wolfram_alpha
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
|
||||||
|
|
||||||
def test_system_custom_only(self):
|
|
||||||
self.maxDiff = None
|
|
||||||
generator = JsonCustomToolGenerator()
|
|
||||||
self.check_generator_output(generator)
|
|
||||||
|
|
||||||
def test_system_custom_function_tag(self):
|
|
||||||
self.maxDiff = None
|
|
||||||
generator = FunctionTagCustomToolGenerator()
|
|
||||||
self.check_generator_output(generator)
|
|
||||||
|
|
||||||
def test_llama_3_2_system_zero_shot(self):
|
|
||||||
generator = PythonListCustomToolGenerator()
|
|
||||||
self.check_generator_output(generator)
|
|
||||||
|
|
||||||
def test_llama_3_2_provided_system_prompt(self):
|
|
||||||
generator = PythonListCustomToolGenerator()
|
|
||||||
user_system_prompt = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Overriding message.
|
|
||||||
|
|
||||||
{{ function_description }}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
example = generator.data_examples()[0]
|
|
||||||
|
|
||||||
pt = generator.gen(example, user_system_prompt)
|
|
||||||
text = pt.render()
|
text = pt.render()
|
||||||
assert "Overriding message." in text
|
if not example:
|
||||||
assert '"name": "get_weather"' in text
|
continue
|
||||||
|
for tool in example:
|
||||||
|
assert tool.tool_name in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_default():
|
||||||
|
generator = SystemDefaultGenerator()
|
||||||
|
today = datetime.now().strftime("%d %B %Y")
|
||||||
|
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||||
|
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_builtin_only():
|
||||||
|
generator = BuiltinToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_custom_only():
|
||||||
|
generator = JsonCustomToolGenerator()
|
||||||
|
check_generator_output(generator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_custom_function_tag():
|
||||||
|
generator = FunctionTagCustomToolGenerator()
|
||||||
|
check_generator_output(generator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_3_2_system_zero_shot():
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
check_generator_output(generator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_3_2_provided_system_prompt():
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
user_system_prompt = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Overriding message.
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
example = generator.data_examples()[0]
|
||||||
|
|
||||||
|
pt = generator.gen(example, user_system_prompt)
|
||||||
|
text = pt.render()
|
||||||
|
assert "Overriding message." in text
|
||||||
|
assert '"name": "get_weather"' in text
|
||||||
|
|
|
||||||
|
|
@ -5,103 +5,110 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
|
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
|
||||||
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
||||||
|
|
||||||
|
|
||||||
class TestNvidiaDatastore(unittest.TestCase):
|
@pytest.fixture
|
||||||
def setUp(self):
|
def nvidia_adapter():
|
||||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
"""Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
|
||||||
|
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||||
|
|
||||||
config = NvidiaDatasetIOConfig(
|
config = NvidiaDatasetIOConfig(
|
||||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||||
)
|
)
|
||||||
self.adapter = NvidiaDatasetIOAdapter(config)
|
adapter = NvidiaDatasetIOAdapter(config)
|
||||||
self.make_request_patcher = patch(
|
|
||||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
|
||||||
)
|
|
||||||
self.mock_make_request = self.make_request_patcher.start()
|
|
||||||
|
|
||||||
def tearDown(self):
|
with patch(
|
||||||
self.make_request_patcher.stop()
|
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||||
|
) as mock_make_request:
|
||||||
|
yield adapter, mock_make_request
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, run_async):
|
|
||||||
self.run_async = run_async
|
|
||||||
|
|
||||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
|
def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
|
||||||
"""Helper method to verify request details in mock calls."""
|
"""Helper function to verify request details in mock calls."""
|
||||||
call_args = mock_call.call_args
|
call_args = mock_call.call_args
|
||||||
|
|
||||||
assert call_args[0][0] == expected_method
|
assert call_args[0][0] == expected_method
|
||||||
assert call_args[0][1] == expected_path
|
assert call_args[0][1] == expected_path
|
||||||
|
|
||||||
if expected_json:
|
if expected_json:
|
||||||
for key, value in expected_json.items():
|
for key, value in expected_json.items():
|
||||||
assert call_args[1]["json"][key] == value
|
assert call_args[1]["json"][key] == value
|
||||||
|
|
||||||
def test_register_dataset(self):
|
|
||||||
self.mock_make_request.return_value = {
|
def test_register_dataset(nvidia_adapter, run_async):
|
||||||
"id": "dataset-123456",
|
adapter, mock_make_request = nvidia_adapter
|
||||||
|
mock_make_request.return_value = {
|
||||||
|
"id": "dataset-123456",
|
||||||
|
"name": "test-dataset",
|
||||||
|
"namespace": "default",
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_def = Dataset(
|
||||||
|
identifier="test-dataset",
|
||||||
|
type=ResourceType.dataset,
|
||||||
|
provider_resource_id="",
|
||||||
|
provider_id="",
|
||||||
|
purpose=DatasetPurpose.post_training_messages,
|
||||||
|
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||||
|
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||||
|
)
|
||||||
|
|
||||||
|
run_async(adapter.register_dataset(dataset_def))
|
||||||
|
|
||||||
|
mock_make_request.assert_called_once()
|
||||||
|
_assert_request(
|
||||||
|
mock_make_request,
|
||||||
|
"POST",
|
||||||
|
"/v1/datasets",
|
||||||
|
expected_json={
|
||||||
"name": "test-dataset",
|
"name": "test-dataset",
|
||||||
"namespace": "default",
|
"namespace": "default",
|
||||||
}
|
"files_url": "https://example.com/data.jsonl",
|
||||||
|
"project": "default",
|
||||||
|
"format": "jsonl",
|
||||||
|
"description": "Test dataset description",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
dataset_def = Dataset(
|
|
||||||
identifier="test-dataset",
|
|
||||||
type="dataset",
|
|
||||||
provider_resource_id="",
|
|
||||||
provider_id="",
|
|
||||||
purpose=DatasetPurpose.post_training_messages,
|
|
||||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
|
||||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
|
||||||
)
|
|
||||||
|
|
||||||
self.run_async(self.adapter.register_dataset(dataset_def))
|
def test_unregister_dataset(nvidia_adapter, run_async):
|
||||||
|
adapter, mock_make_request = nvidia_adapter
|
||||||
|
mock_make_request.return_value = {
|
||||||
|
"message": "Resource deleted successfully.",
|
||||||
|
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||||
|
"deleted_at": None,
|
||||||
|
}
|
||||||
|
dataset_id = "test-dataset"
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
run_async(adapter.unregister_dataset(dataset_id))
|
||||||
self._assert_request(
|
|
||||||
self.mock_make_request,
|
|
||||||
"POST",
|
|
||||||
"/v1/datasets",
|
|
||||||
expected_json={
|
|
||||||
"name": "test-dataset",
|
|
||||||
"namespace": "default",
|
|
||||||
"files_url": "https://example.com/data.jsonl",
|
|
||||||
"project": "default",
|
|
||||||
"format": "jsonl",
|
|
||||||
"description": "Test dataset description",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unregister_dataset(self):
|
mock_make_request.assert_called_once()
|
||||||
self.mock_make_request.return_value = {
|
_assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||||
"message": "Resource deleted successfully.",
|
|
||||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
|
||||||
"deleted_at": None,
|
|
||||||
}
|
|
||||||
dataset_id = "test-dataset"
|
|
||||||
|
|
||||||
self.run_async(self.adapter.unregister_dataset(dataset_id))
|
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
def test_register_dataset_with_custom_namespace_project(run_async):
|
||||||
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
"""Test with custom namespace and project configuration."""
|
||||||
|
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||||
|
|
||||||
def test_register_dataset_with_custom_namespace_project(self):
|
custom_config = NvidiaDatasetIOConfig(
|
||||||
custom_config = NvidiaDatasetIOConfig(
|
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
dataset_namespace="custom-namespace",
|
||||||
dataset_namespace="custom-namespace",
|
project_id="custom-project",
|
||||||
project_id="custom-project",
|
)
|
||||||
)
|
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
|
||||||
|
|
||||||
self.mock_make_request.return_value = {
|
with patch(
|
||||||
|
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||||
|
) as mock_make_request:
|
||||||
|
mock_make_request.return_value = {
|
||||||
"id": "dataset-123456",
|
"id": "dataset-123456",
|
||||||
"name": "test-dataset",
|
"name": "test-dataset",
|
||||||
"namespace": "custom-namespace",
|
"namespace": "custom-namespace",
|
||||||
|
|
@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
|
||||||
|
|
||||||
dataset_def = Dataset(
|
dataset_def = Dataset(
|
||||||
identifier="test-dataset",
|
identifier="test-dataset",
|
||||||
type="dataset",
|
type=ResourceType.dataset,
|
||||||
provider_resource_id="",
|
provider_resource_id="",
|
||||||
provider_id="",
|
provider_id="",
|
||||||
purpose=DatasetPurpose.post_training_messages,
|
purpose=DatasetPurpose.post_training_messages,
|
||||||
|
|
@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase):
|
||||||
metadata={"format": "jsonl"},
|
metadata={"format": "jsonl"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.run_async(custom_adapter.register_dataset(dataset_def))
|
run_async(custom_adapter.register_dataset(dataset_def))
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
mock_make_request.assert_called_once()
|
||||||
self._assert_request(
|
_assert_request(
|
||||||
self.mock_make_request,
|
mock_make_request,
|
||||||
"POST",
|
"POST",
|
||||||
"/v1/datasets",
|
"/v1/datasets",
|
||||||
expected_json={
|
expected_json={
|
||||||
|
|
@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase):
|
||||||
"format": "jsonl",
|
"format": "jsonl",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
import warnings
|
import warnings
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
@ -27,14 +26,13 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestNvidiaParameters(unittest.TestCase):
|
class TestNvidiaParameters:
|
||||||
def setUp(self):
|
@pytest.fixture(autouse=True)
|
||||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
def setup_and_teardown(self):
|
||||||
|
"""Setup and teardown for each test method."""
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
|
|
||||||
config = NvidiaPostTrainingConfig(
|
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
|
||||||
)
|
|
||||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||||
|
|
||||||
self.make_request_patcher = patch(
|
self.make_request_patcher = patch(
|
||||||
|
|
@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
"updated_at": "2025-03-04T13:07:47.543605",
|
"updated_at": "2025-03-04T13:07:47.543605",
|
||||||
}
|
}
|
||||||
|
|
||||||
def tearDown(self):
|
yield
|
||||||
|
|
||||||
self.make_request_patcher.stop()
|
self.make_request_patcher.stop()
|
||||||
|
|
||||||
def _assert_request_params(self, expected_json):
|
def _assert_request_params(self, expected_json):
|
||||||
|
|
@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
|
|
||||||
self.run_async(
|
self.run_async(
|
||||||
self.adapter.supervised_fine_tune(
|
self.adapter.supervised_fine_tune(
|
||||||
job_uuid=required_job_uuid, # Required parameter
|
job_uuid=required_job_uuid,
|
||||||
model=required_model, # Required parameter
|
model=required_model,
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
|
|
@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
data_config = DataConfig(
|
data_config = DataConfig(
|
||||||
dataset_id="test-dataset",
|
dataset_id="test-dataset",
|
||||||
batch_size=8,
|
batch_size=8,
|
||||||
# Unsupported parameters
|
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
data_format=DatasetFormat.instruct,
|
data_format=DatasetFormat.instruct,
|
||||||
validation_dataset_id="val-dataset",
|
validation_dataset_id="val-dataset",
|
||||||
|
|
@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
optimizer_config = OptimizerConfig(
|
optimizer_config = OptimizerConfig(
|
||||||
lr=0.0001,
|
lr=0.0001,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
# Unsupported parameters
|
|
||||||
optimizer_type=OptimizerType.adam,
|
optimizer_type=OptimizerType.adam,
|
||||||
num_warmup_steps=100,
|
num_warmup_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
efficiency_config = EfficiencyConfig(
|
efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
|
||||||
enable_activation_checkpointing=True # Unsupported parameter
|
|
||||||
)
|
|
||||||
|
|
||||||
training_config = TrainingConfig(
|
training_config = TrainingConfig(
|
||||||
n_epochs=1,
|
n_epochs=1,
|
||||||
data_config=data_config,
|
data_config=data_config,
|
||||||
optimizer_config=optimizer_config,
|
optimizer_config=optimizer_config,
|
||||||
# Unsupported parameters
|
|
||||||
efficiency_config=efficiency_config,
|
efficiency_config=efficiency_config,
|
||||||
max_steps_per_epoch=1000,
|
max_steps_per_epoch=1000,
|
||||||
gradient_accumulation_steps=4,
|
gradient_accumulation_steps=4,
|
||||||
|
|
@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
dtype="bf16",
|
dtype="bf16",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Capture warnings
|
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
|
|
@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
self.adapter.supervised_fine_tune(
|
self.adapter.supervised_fine_tune(
|
||||||
job_uuid="test-job",
|
job_uuid="test-job",
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
checkpoint_dir="test-dir", # Unsupported parameter
|
checkpoint_dir="test-dir",
|
||||||
algorithm_config=LoraFinetuningConfig(
|
algorithm_config=LoraFinetuningConfig(
|
||||||
type="LoRA",
|
type="LoRA",
|
||||||
apply_lora_to_mlp=True,
|
apply_lora_to_mlp=True,
|
||||||
|
|
@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
),
|
),
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={"test": "value"}, # Unsupported parameter
|
logger_config={"test": "value"},
|
||||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
hyperparam_search_config={"test": "value"},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
"gradient_accumulation_steps",
|
"gradient_accumulation_steps",
|
||||||
"max_validation_steps",
|
"max_validation_steps",
|
||||||
"dtype",
|
"dtype",
|
||||||
# required unsupported parameters
|
|
||||||
"rank",
|
"rank",
|
||||||
"apply_lora_to_output",
|
"apply_lora_to_output",
|
||||||
"lora_attn_modules",
|
"lora_attn_modules",
|
||||||
|
|
@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
]
|
]
|
||||||
for field in fields:
|
for field in fields:
|
||||||
assert any(field in text for text in warning_texts)
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
import warnings
|
import warnings
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.models import Model, ModelType
|
|
||||||
from llama_stack.apis.post_training.post_training import (
|
from llama_stack.apis.post_training.post_training import (
|
||||||
DataConfig,
|
DataConfig,
|
||||||
DatasetFormat,
|
DatasetFormat,
|
||||||
|
|
@ -22,7 +20,6 @@ from llama_stack.apis.post_training.post_training import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
|
|
||||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||||
ListNvidiaPostTrainingJobs,
|
ListNvidiaPostTrainingJobs,
|
||||||
NvidiaPostTrainingAdapter,
|
NvidiaPostTrainingAdapter,
|
||||||
|
|
@ -32,336 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestNvidiaPostTraining(unittest.TestCase):
|
@pytest.fixture
|
||||||
def setUp(self):
|
def nvidia_post_training_adapter():
|
||||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
|
"""Fixture to create and configure the NVIDIA post training adapter."""
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||||
|
|
||||||
config = NvidiaPostTrainingConfig(
|
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
adapter = NvidiaPostTrainingAdapter(config)
|
||||||
|
|
||||||
|
with patch.object(adapter, "_make_request") as mock_make_request:
|
||||||
|
yield adapter, mock_make_request
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||||
|
"""Helper method to verify request details in mock calls."""
|
||||||
|
call_args = mock_call.call_args
|
||||||
|
|
||||||
|
if expected_method and expected_path:
|
||||||
|
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||||
|
assert call_args[0] == (expected_method, expected_path)
|
||||||
|
else:
|
||||||
|
assert call_args[1]["method"] == expected_method
|
||||||
|
assert call_args[1]["path"] == expected_path
|
||||||
|
|
||||||
|
if expected_params:
|
||||||
|
assert call_args[1]["params"] == expected_params
|
||||||
|
|
||||||
|
if expected_json:
|
||||||
|
for key, value in expected_json.items():
|
||||||
|
assert call_args[1]["json"][key] == value
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervised_fine_tune(nvidia_post_training_adapter):
|
||||||
|
"""Test the supervised fine-tuning API call."""
|
||||||
|
adapter, mock_make_request = nvidia_post_training_adapter
|
||||||
|
mock_make_request.return_value = {
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"config": {
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542657",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.569837",
|
||||||
|
"custom_fields": {},
|
||||||
|
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"model_path": "llama-3_1-8b-instruct",
|
||||||
|
"training_types": [],
|
||||||
|
"finetuning_types": ["lora"],
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_gpus": 4,
|
||||||
|
"num_nodes": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"max_seq_length": 4096,
|
||||||
|
},
|
||||||
|
"dataset": {
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542657",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542660",
|
||||||
|
"custom_fields": {},
|
||||||
|
"name": "sample-basic-test",
|
||||||
|
"version_id": "main",
|
||||||
|
"version_tags": [],
|
||||||
|
},
|
||||||
|
"hyperparameters": {
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"training_type": "sft",
|
||||||
|
"batch_size": 16,
|
||||||
|
"epochs": 2,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"lora": {"alpha": 16},
|
||||||
|
},
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
"status": "created",
|
||||||
|
"project": "default",
|
||||||
|
"custom_fields": {},
|
||||||
|
"ownership": {"created_by": "me", "access_policies": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
algorithm_config = LoraFinetuningConfig(
|
||||||
|
type="LoRA",
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=True,
|
||||||
|
alpha=16,
|
||||||
|
rank=16,
|
||||||
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
|
)
|
||||||
|
|
||||||
|
data_config = DataConfig(
|
||||||
|
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=2,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True):
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
training_job = await adapter.supervised_fine_tune(
|
||||||
|
job_uuid="1234",
|
||||||
|
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
|
checkpoint_dir="",
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
)
|
)
|
||||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
|
||||||
self.make_request_patcher = patch(
|
|
||||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
|
||||||
)
|
|
||||||
self.mock_make_request = self.make_request_patcher.start()
|
|
||||||
|
|
||||||
# Mock the inference client
|
# check the output is a PostTrainingJob
|
||||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||||
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||||
|
|
||||||
self.mock_client = unittest.mock.MagicMock()
|
mock_make_request.assert_called_once()
|
||||||
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
_assert_request(
|
||||||
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
mock_make_request,
|
||||||
self.inference_make_request_patcher = patch(
|
"POST",
|
||||||
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
|
"/v1/customization/jobs",
|
||||||
new_callable=unittest.mock.PropertyMock,
|
expected_json={
|
||||||
return_value=self.mock_client,
|
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
)
|
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||||
self.inference_make_request_patcher.start()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.make_request_patcher.stop()
|
|
||||||
self.inference_make_request_patcher.stop()
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, run_async):
|
|
||||||
self.run_async = run_async
|
|
||||||
|
|
||||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
|
||||||
"""Helper method to verify request details in mock calls."""
|
|
||||||
call_args = mock_call.call_args
|
|
||||||
|
|
||||||
if expected_method and expected_path:
|
|
||||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
|
||||||
assert call_args[0] == (expected_method, expected_path)
|
|
||||||
else:
|
|
||||||
assert call_args[1]["method"] == expected_method
|
|
||||||
assert call_args[1]["path"] == expected_path
|
|
||||||
|
|
||||||
if expected_params:
|
|
||||||
assert call_args[1]["params"] == expected_params
|
|
||||||
|
|
||||||
if expected_json:
|
|
||||||
for key, value in expected_json.items():
|
|
||||||
assert call_args[1]["json"][key] == value
|
|
||||||
|
|
||||||
def test_supervised_fine_tune(self):
|
|
||||||
"""Test the supervised fine-tuning API call."""
|
|
||||||
self.mock_make_request.return_value = {
|
|
||||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
|
||||||
"created_at": "2024-12-09T04:06:28.542884",
|
|
||||||
"updated_at": "2024-12-09T04:06:28.542884",
|
|
||||||
"config": {
|
|
||||||
"schema_version": "1.0",
|
|
||||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
|
||||||
"created_at": "2024-12-09T04:06:28.542657",
|
|
||||||
"updated_at": "2024-12-09T04:06:28.569837",
|
|
||||||
"custom_fields": {},
|
|
||||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
"model_path": "llama-3_1-8b-instruct",
|
|
||||||
"training_types": [],
|
|
||||||
"finetuning_types": ["lora"],
|
|
||||||
"precision": "bf16",
|
|
||||||
"num_gpus": 4,
|
|
||||||
"num_nodes": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"tensor_parallel_size": 1,
|
|
||||||
"max_seq_length": 4096,
|
|
||||||
},
|
|
||||||
"dataset": {
|
|
||||||
"schema_version": "1.0",
|
|
||||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
|
||||||
"created_at": "2024-12-09T04:06:28.542657",
|
|
||||||
"updated_at": "2024-12-09T04:06:28.542660",
|
|
||||||
"custom_fields": {},
|
|
||||||
"name": "sample-basic-test",
|
|
||||||
"version_id": "main",
|
|
||||||
"version_tags": [],
|
|
||||||
},
|
|
||||||
"hyperparameters": {
|
"hyperparameters": {
|
||||||
"finetuning_type": "lora",
|
|
||||||
"training_type": "sft",
|
"training_type": "sft",
|
||||||
"batch_size": 16,
|
"finetuning_type": "lora",
|
||||||
"epochs": 2,
|
"epochs": 2,
|
||||||
|
"batch_size": 16,
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
|
"weight_decay": 0.01,
|
||||||
"lora": {"alpha": 16},
|
"lora": {"alpha": 16},
|
||||||
},
|
},
|
||||||
"output_model": "default/job-1234",
|
},
|
||||||
"status": "created",
|
)
|
||||||
"project": "default",
|
|
||||||
"custom_fields": {},
|
|
||||||
"ownership": {"created_by": "me", "access_policies": {}},
|
async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
|
||||||
|
"""Test that QAT configuration raises NotImplementedError."""
|
||||||
|
adapter, mock_make_request = nvidia_post_training_adapter
|
||||||
|
|
||||||
|
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||||
|
data_config = DataConfig(
|
||||||
|
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
|
)
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
|
)
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=2,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will raise NotImplementedError since QAT is not supported
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await adapter.supervised_fine_tune(
|
||||||
|
job_uuid="1234",
|
||||||
|
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
|
checkpoint_dir="",
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_training_job_status(nvidia_post_training_adapter):
|
||||||
|
"""Test getting training job status with different statuses."""
|
||||||
|
adapter, mock_make_request = nvidia_post_training_adapter
|
||||||
|
|
||||||
|
customizer_status_to_job_status = [
|
||||||
|
("running", "in_progress"),
|
||||||
|
("completed", "completed"),
|
||||||
|
("failed", "failed"),
|
||||||
|
("cancelled", "cancelled"),
|
||||||
|
("pending", "scheduled"),
|
||||||
|
("unknown", "scheduled"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||||
|
mock_make_request.return_value = {
|
||||||
|
"created_at": "2024-12-09T04:06:28.580220",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"status": customizer_status,
|
||||||
|
"steps_completed": 1210,
|
||||||
|
"epochs_completed": 2,
|
||||||
|
"percentage_done": 100.0,
|
||||||
|
"best_epoch": 2,
|
||||||
|
"train_loss": 1.718016266822815,
|
||||||
|
"val_loss": 1.8661999702453613,
|
||||||
}
|
}
|
||||||
|
|
||||||
algorithm_config = LoraFinetuningConfig(
|
|
||||||
type="LoRA",
|
|
||||||
apply_lora_to_mlp=True,
|
|
||||||
apply_lora_to_output=True,
|
|
||||||
alpha=16,
|
|
||||||
rank=16,
|
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
|
||||||
)
|
|
||||||
|
|
||||||
data_config = DataConfig(
|
|
||||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_config = OptimizerConfig(
|
|
||||||
optimizer_type=OptimizerType.adam,
|
|
||||||
lr=0.0001,
|
|
||||||
weight_decay=0.01,
|
|
||||||
num_warmup_steps=100,
|
|
||||||
)
|
|
||||||
|
|
||||||
training_config = TrainingConfig(
|
|
||||||
n_epochs=2,
|
|
||||||
data_config=data_config,
|
|
||||||
optimizer_config=optimizer_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
with warnings.catch_warnings(record=True):
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
training_job = self.run_async(
|
|
||||||
self.adapter.supervised_fine_tune(
|
|
||||||
job_uuid="1234",
|
|
||||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
|
||||||
checkpoint_dir="",
|
|
||||||
algorithm_config=algorithm_config,
|
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
|
||||||
logger_config={},
|
|
||||||
hyperparam_search_config={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# check the output is a PostTrainingJob
|
|
||||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
|
||||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
|
||||||
self._assert_request(
|
|
||||||
self.mock_make_request,
|
|
||||||
"POST",
|
|
||||||
"/v1/customization/jobs",
|
|
||||||
expected_json={
|
|
||||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
|
||||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
|
||||||
"hyperparameters": {
|
|
||||||
"training_type": "sft",
|
|
||||||
"finetuning_type": "lora",
|
|
||||||
"epochs": 2,
|
|
||||||
"batch_size": 16,
|
|
||||||
"learning_rate": 0.0001,
|
|
||||||
"weight_decay": 0.01,
|
|
||||||
"lora": {"alpha": 16},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_supervised_fine_tune_with_qat(self):
|
|
||||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
|
||||||
data_config = DataConfig(
|
|
||||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
|
||||||
)
|
|
||||||
optimizer_config = OptimizerConfig(
|
|
||||||
optimizer_type=OptimizerType.adam,
|
|
||||||
lr=0.0001,
|
|
||||||
weight_decay=0.01,
|
|
||||||
num_warmup_steps=100,
|
|
||||||
)
|
|
||||||
training_config = TrainingConfig(
|
|
||||||
n_epochs=2,
|
|
||||||
data_config=data_config,
|
|
||||||
optimizer_config=optimizer_config,
|
|
||||||
)
|
|
||||||
# This will raise NotImplementedError since QAT is not supported
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
self.run_async(
|
|
||||||
self.adapter.supervised_fine_tune(
|
|
||||||
job_uuid="1234",
|
|
||||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
|
||||||
checkpoint_dir="",
|
|
||||||
algorithm_config=algorithm_config,
|
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
|
||||||
logger_config={},
|
|
||||||
hyperparam_search_config={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_get_training_job_status(self):
|
|
||||||
customizer_status_to_job_status = [
|
|
||||||
("running", "in_progress"),
|
|
||||||
("completed", "completed"),
|
|
||||||
("failed", "failed"),
|
|
||||||
("cancelled", "cancelled"),
|
|
||||||
("pending", "scheduled"),
|
|
||||||
("unknown", "scheduled"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
|
||||||
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
|
||||||
self.mock_make_request.return_value = {
|
|
||||||
"created_at": "2024-12-09T04:06:28.580220",
|
|
||||||
"updated_at": "2024-12-09T04:21:19.852832",
|
|
||||||
"status": customizer_status,
|
|
||||||
"steps_completed": 1210,
|
|
||||||
"epochs_completed": 2,
|
|
||||||
"percentage_done": 100.0,
|
|
||||||
"best_epoch": 2,
|
|
||||||
"train_loss": 1.718016266822815,
|
|
||||||
"val_loss": 1.8661999702453613,
|
|
||||||
}
|
|
||||||
|
|
||||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
|
||||||
|
|
||||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
|
||||||
|
|
||||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
|
||||||
assert status.status.value == expected_status
|
|
||||||
assert status.steps_completed == 1210
|
|
||||||
assert status.epochs_completed == 2
|
|
||||||
assert status.percentage_done == 100.0
|
|
||||||
assert status.best_epoch == 2
|
|
||||||
assert status.train_loss == 1.718016266822815
|
|
||||||
assert status.val_loss == 1.8661999702453613
|
|
||||||
|
|
||||||
self._assert_request(
|
|
||||||
self.mock_make_request,
|
|
||||||
"GET",
|
|
||||||
f"/v1/customization/jobs/{job_id}/status",
|
|
||||||
expected_params={"job_id": job_id},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_get_training_jobs(self):
|
|
||||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||||
self.mock_make_request.return_value = {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": job_id,
|
|
||||||
"created_at": "2024-12-09T04:06:28.542884",
|
|
||||||
"updated_at": "2024-12-09T04:21:19.852832",
|
|
||||||
"config": {
|
|
||||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
},
|
|
||||||
"dataset": {"name": "default/sample-basic-test"},
|
|
||||||
"hyperparameters": {
|
|
||||||
"finetuning_type": "lora",
|
|
||||||
"training_type": "sft",
|
|
||||||
"batch_size": 16,
|
|
||||||
"epochs": 2,
|
|
||||||
"learning_rate": 0.0001,
|
|
||||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
|
||||||
},
|
|
||||||
"output_model": "default/job-1234",
|
|
||||||
"status": "completed",
|
|
||||||
"project": "default",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
jobs = self.run_async(self.adapter.get_training_jobs())
|
status = await adapter.get_training_job_status(job_uuid=job_id)
|
||||||
|
|
||||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||||
assert len(jobs.data) == 1
|
assert status.status.value == expected_status
|
||||||
job = jobs.data[0]
|
# Note: The response object inherits extra fields via ConfigDict(extra="allow")
|
||||||
assert job.job_uuid == job_id
|
# So these attributes should be accessible using getattr with defaults
|
||||||
assert job.status.value == "completed"
|
assert getattr(status, "steps_completed", None) == 1210
|
||||||
|
assert getattr(status, "epochs_completed", None) == 2
|
||||||
|
assert getattr(status, "percentage_done", None) == 100.0
|
||||||
|
assert getattr(status, "best_epoch", None) == 2
|
||||||
|
assert getattr(status, "train_loss", None) == 1.718016266822815
|
||||||
|
assert getattr(status, "val_loss", None) == 1.8661999702453613
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
_assert_request(
|
||||||
self._assert_request(
|
mock_make_request,
|
||||||
self.mock_make_request,
|
|
||||||
"GET",
|
"GET",
|
||||||
"/v1/customization/jobs",
|
f"/v1/customization/jobs/{job_id}/status",
|
||||||
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_cancel_training_job(self):
|
|
||||||
self.mock_make_request.return_value = {} # Empty response for successful cancellation
|
|
||||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
|
||||||
|
|
||||||
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
|
||||||
self._assert_request(
|
|
||||||
self.mock_make_request,
|
|
||||||
"POST",
|
|
||||||
f"/v1/customization/jobs/{job_id}/cancel",
|
|
||||||
expected_params={"job_id": job_id},
|
expected_params={"job_id": job_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_register_model(self):
|
mock_make_request.reset_mock()
|
||||||
model_id = "default/job-1234"
|
|
||||||
model_type = ModelType.llm
|
|
||||||
model = Model(
|
|
||||||
identifier=model_id,
|
|
||||||
provider_id="nvidia",
|
|
||||||
provider_model_id=model_id,
|
|
||||||
provider_resource_id=model_id,
|
|
||||||
model_type=model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# simulate a NIM where default/job-1234 is an available model
|
|
||||||
with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
|
|
||||||
mock_check.return_value = True
|
|
||||||
result = self.run_async(self.inference_adapter.register_model(model))
|
|
||||||
assert result == model
|
|
||||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
|
||||||
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
|
||||||
|
|
||||||
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
|
|
||||||
self.run_async(
|
|
||||||
self.inference_adapter.chat_completion(
|
|
||||||
model_id=model_id,
|
|
||||||
messages=[{"role": "user", "content": "Hello, model"}],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_chat_completion.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
async def test_get_training_jobs(nvidia_post_training_adapter):
|
||||||
unittest.main()
|
"""Test getting list of training jobs."""
|
||||||
|
adapter, mock_make_request = nvidia_post_training_adapter
|
||||||
|
|
||||||
|
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||||
|
mock_make_request.return_value = {
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": job_id,
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"config": {
|
||||||
|
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
},
|
||||||
|
"dataset": {"name": "default/sample-basic-test"},
|
||||||
|
"hyperparameters": {
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"training_type": "sft",
|
||||||
|
"batch_size": 16,
|
||||||
|
"epochs": 2,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||||
|
},
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
"status": "completed",
|
||||||
|
"project": "default",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
jobs = await adapter.get_training_jobs()
|
||||||
|
|
||||||
|
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||||
|
assert len(jobs.data) == 1
|
||||||
|
job = jobs.data[0]
|
||||||
|
assert job.job_uuid == job_id
|
||||||
|
assert job.status.value == "completed"
|
||||||
|
|
||||||
|
mock_make_request.assert_called_once()
|
||||||
|
_assert_request(
|
||||||
|
mock_make_request,
|
||||||
|
"GET",
|
||||||
|
"/v1/customization/jobs",
|
||||||
|
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cancel_training_job(nvidia_post_training_adapter):
|
||||||
|
"""Test canceling a training job."""
|
||||||
|
adapter, mock_make_request = nvidia_post_training_adapter
|
||||||
|
|
||||||
|
mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||||
|
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||||
|
|
||||||
|
result = await adapter.cancel_training_job(job_uuid=job_id)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
mock_make_request.assert_called_once()
|
||||||
|
_assert_request(
|
||||||
|
mock_make_request,
|
||||||
|
"POST",
|
||||||
|
f"/v1/customization/jobs/{job_id}/cancel",
|
||||||
|
expected_params={"job_id": job_id},
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
|
|
||||||
|
|
@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
MILVUS_PROVIDER = "milvus"
|
MILVUS_PROVIDER = "milvus"
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def mock_milvus_client() -> MagicMock:
|
async def mock_milvus_client() -> MagicMock:
|
||||||
"""Create a mock Milvus client with common method behaviors."""
|
"""Create a mock Milvus client with common method behaviors."""
|
||||||
client = MagicMock()
|
client = MagicMock()
|
||||||
|
|
@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock:
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def milvus_index(mock_milvus_client):
|
async def milvus_index(mock_milvus_client):
|
||||||
"""Create a MilvusIndex with mocked client."""
|
"""Create a MilvusIndex with mocked client."""
|
||||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||||
|
|
@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client):
|
||||||
# No real cleanup needed since we're using mocks
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
# Setup: collection doesn't exist initially, then exists after creation
|
# Setup: collection doesn't exist initially, then exists after creation
|
||||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||||
|
|
@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m
|
||||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_vector(
|
async def test_query_chunks_vector(
|
||||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
):
|
):
|
||||||
|
|
@ -125,7 +122,6 @@ async def test_query_chunks_vector(
|
||||||
mock_milvus_client.search.assert_called_once()
|
mock_milvus_client.search.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
|
||||||
assert len(response.chunks) == 2
|
assert len(response.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
|
@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
|
||||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
# Test collection deletion
|
# Test collection deletion
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,6 @@ class TestRagQuery:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
RAGQueryConfig(mode="invalid_mode")
|
RAGQueryConfig(mode="invalid_mode")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_accepts_valid_modes(self):
|
async def test_query_accepts_valid_modes(self):
|
||||||
RAGQueryConfig() # Test default (vector)
|
RAGQueryConfig() # Test default (vector)
|
||||||
RAGQueryConfig(mode="vector") # Test vector
|
RAGQueryConfig(mode="vector") # Test vector
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1390,6 +1390,7 @@ unit = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "blobfile" },
|
{ name = "blobfile" },
|
||||||
{ name = "chardet" },
|
{ name = "chardet" },
|
||||||
|
{ name = "coverage" },
|
||||||
{ name = "faiss-cpu" },
|
{ name = "faiss-cpu" },
|
||||||
{ name = "litellm" },
|
{ name = "litellm" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
|
|
@ -1499,6 +1500,7 @@ unit = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "blobfile" },
|
{ name = "blobfile" },
|
||||||
{ name = "chardet" },
|
{ name = "chardet" },
|
||||||
|
{ name = "coverage" },
|
||||||
{ name = "faiss-cpu" },
|
{ name = "faiss-cpu" },
|
||||||
{ name = "litellm" },
|
{ name = "litellm" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue