Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
jhpiedrahitao 2025-04-01 07:57:21 -05:00
commit 9c9f9577e2
173 changed files with 3073 additions and 3118 deletions

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
@pytest.fixture(scope="session")
def post_training_provider_available(llama_stack_client):
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
return len(post_training_providers) > 0
@pytest.mark.integration
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
"""Check if post_training is in the api list.
This is a sanity check to ensure the provider is registered."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
assert len(post_training_providers) > 0
@pytest.mark.integration
def test_get_training_jobs(llama_stack_client, post_training_provider_available):
"""Test listing all training jobs."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
assert isinstance(jobs, dict)
assert "data" in jobs
assert isinstance(jobs["data"], list)
@pytest.mark.integration
def test_get_training_job_status(llama_stack_client, post_training_provider_available):
"""Test getting status of a specific training job."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
if not jobs["data"]:
pytest.skip("No training jobs available to check status")
job_uuid = jobs["data"][0]["job_uuid"]
job_status = llama_stack_client.post_training.get_training_job_status(job_uuid=job_uuid)
assert job_status is not None
assert "job_uuid" in job_status
assert "status" in job_status
assert job_status["job_uuid"] == job_uuid

View file

@@ -23,8 +23,8 @@ Model parameters can be influenced by the following options:
- `--judge-model`: comma-separated list of judge models.
- `--embedding-dimension`: output dimensionality of the embedding model to use for testing. Default: 384
Each of these are comma-separated lists and can be used to generate multiple parameter combinations.
Each of these options accepts a comma-separated list and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified.
Experimental options (under development):
- `--record-responses`: record new API responses instead of using cached ones
@@ -36,7 +36,7 @@ Experimental options (under development):
Run all text inference tests with the `together` distribution:
```bash
pytest -s -v tests/api/inference/test_text_inference.py \
pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=together \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
@@ -44,7 +44,7 @@ pytest -s -v tests/api/inference/test_text_inference.py \
Run all text inference tests with the `together` distribution and `meta-llama/Llama-3.1-8B-Instruct`:
```bash
pytest -s -v tests/api/inference/test_text_inference.py \
pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=together \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
@@ -57,7 +57,7 @@ VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
EMBEDDING_MODELS=all-MiniLM-L6-v2
export TOGETHER_API_KEY=<together_api_key>
pytest -s -v tests/api/inference/ \
pytest -s -v tests/integration/inference/ \
--stack-config=together \
--text-model=$TEXT_MODELS \
--vision-model=$VISION_MODELS \
@@ -69,7 +69,7 @@ Same thing but instead of using the distribution, use an adhoc stack with just o
```bash
export FIREWORKS_API_KEY=<fireworks_api_key>
pytest -s -v tests/api/inference/ \
pytest -s -v tests/integration/inference/ \
--stack-config=inference=fireworks \
--text-model=$TEXT_MODELS \
--vision-model=$VISION_MODELS \
@@ -81,7 +81,7 @@ Running Vector IO tests for a number of embedding models:
```bash
EMBEDDING_MODELS=all-MiniLM-L6-v2
pytest -s -v tests/api/vector_io/ \
pytest -s -v tests/integration/vector_io/ \
--stack-config=inference=sentence-transformers,vector_io=sqlite-vec \
--embedding-model=$EMBEDDING_MODELS
```

View file

@@ -173,6 +173,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant that can use web search to answer questions.",
"tools": [
"builtin::websearch",
],
@@ -184,20 +185,20 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
messages=[
{
"role": "user",
"content": "Search the web and tell me who the founder of Meta is.",
"content": "Search the web and tell me what is the local time in Tokyo currently.",
}
],
session_id=session_id,
stream=False,
)
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "tool_execution>" in logs_str
assert "Tool:brave_search Response:" in logs_str
assert "mark zuckerberg" in logs_str.lower()
if len(agent_config["output_shields"]) > 0:
assert "No Violation" in logs_str
found_tool_execution = False
for step in response.steps:
if step.step_type == "tool_execution":
assert step.tool_calls[0].tool_name == "brave_search"
found_tool_execution = True
break
assert found_tool_execution
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
@@ -427,19 +428,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize(
"tool",
[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [],
},
),
"builtin::rag/knowledge_search",
],
)
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@@ -452,7 +441,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
]
agent_config = {
**agent_config,
"tools": [tool],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@@ -486,10 +474,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
stream=False,
)
# rag is called
tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
assert len(tool_execution_step) >= 1
assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
assert "lora" in response.output_message.content.lower()
@@ -536,19 +520,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
],
}
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
inflation_doc = Document(
document_id="test_csv",
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="text/csv",
metadata={},
)
user_prompts = [
(
"Here is a csv file, can you describe it?",
[inflation_doc],
"code_interpreter",
"",
),
(
"when was Perplexity the company founded?",
[],

View file

@@ -117,6 +117,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
assert len(content_str) > 10
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:stop_sequence",
],
)
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
if inference_provider_type != "remote::vllm":
pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
"max_tokens": 50,
"stop": ["1963"],
},
)
streamed_content = [chunk.delta for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert "1963" not in content_str
@pytest.mark.parametrize(
"test_case",
[
@@ -266,6 +293,7 @@ def test_text_chat_completion_first_token_profiling(client_with_models, text_mod
model_id=text_model_id,
messages=messages,
stream=False,
timeout=120, # Increase timeout to 2 minutes for large conversation history
)
message_content = response.completion_message.content.lower().strip()
assert len(message_content) > 0
@@ -292,6 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
model_id=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120, # Increase timeout to 2 minutes for large conversation history
)
streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
assert len(streamed_content) > 0

View file

@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from uuid import uuid4
from llama_stack_client import Agent
def test_agent_query_spans(llama_stack_client, text_model_id):
agent = Agent(llama_stack_client, model=text_model_id, instructions="You are a helpful assistant")
session_id = agent.create_session(f"test-session-{uuid4()}")
agent.create_turn(
messages=[
{
"role": "user",
"content": "Give me a sentence that contains the word: hello",
}
],
session_id=session_id,
stream=False,
)
# Wait for the span to be logged
time.sleep(2)
agent_logs = []
for span in llama_stack_client.telemetry.query_spans(
attribute_filters=[
{"key": "session_id", "op": "eq", "value": session_id},
],
attributes_to_return=["input", "output"],
):
if span.attributes["output"] != "no shields":
agent_logs.append(span.attributes)
assert len(agent_logs) == 1
assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
assert "hello" in agent_logs[0]["output"].lower()

View file

@@ -10,6 +10,11 @@
"expected": "1963"
}
},
"stop_sequence": {
"data": {
"content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
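# Shared mock aiohttp session used by these unit tests: it reports itself as open, turns
# close() and the async context-manager protocol into no-ops, and is patched in for the
# whole test session by the autouse fixture below so no real HTTP connections are made.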
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
with patch("aiohttp.ClientSession", return_value=mock_session):
yield
@pytest.fixture
def event_loop():
"""Create and provide a new event loop for each test."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture
def run_async():
"""Fixture to run async functions in tests."""
def _run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
return _run_async
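# Illustrative usage sketch, not part of this diff: any coroutine can be driven synchronously
# through the run_async fixture defined above. The coroutine and test names are hypothetical.
async def _echo_async(value):
    return value

def test_run_async_sketch(run_async):
    assert run_async(_echo_async("ok")) == "ok"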

View file

@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert actual_json[key][nested_key] == nested_value
else:
assert actual_json[key] == value
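# Bridge between pytest and unittest.TestCase: the autouse fixture below requests the
# run_async helper from conftest.py and stores it on the instance so synchronous test
# methods can drive the adapter's async API.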
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim,
adapter_dropout=0.2,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
}
}
)
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
max_validation_steps=100,
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=training_config,
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
fields = [
"checkpoint_dir",
"hyperparam_search_config",
"logger_config",
"TrainingConfig",
"DataConfig",
"OptimizerConfig",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# fields required by LoraFinetuningConfig but not supported by the provider
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJob,
NvidiaPostTrainingJobStatusResponse,
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
# check that the output is an NvidiaPostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)
if __name__ == "__main__":
unittest.main()

View file

@@ -14,17 +14,10 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:
"""Test suite for testing provider configurations across all API types."""
def test_all_api_providers_exist(self):
provider_registry = get_provider_registry()
for api in providable_apis():
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
@pytest.mark.parametrize("api", providable_apis())
def test_api_providers(self, api):
provider_registry = get_provider_registry()
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
failures = []
for provider_type, provider_spec in providers.items():