mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
add test cases
This commit is contained in:
parent
6a0c38f123
commit
409383ae5f
5 changed files with 425 additions and 29 deletions
|
@ -55,7 +55,7 @@ from .openai_utils import (
|
||||||
convert_openai_completion_choice,
|
convert_openai_completion_choice,
|
||||||
convert_openai_completion_stream,
|
convert_openai_completion_stream,
|
||||||
)
|
)
|
||||||
from .utils import _is_nvidia_hosted, check_health
|
from .utils import _is_nvidia_hosted
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -134,7 +134,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
if content_has_media(content):
|
if content_has_media(content):
|
||||||
raise NotImplementedError("Media is not supported")
|
raise NotImplementedError("Media is not supported")
|
||||||
|
|
||||||
await check_health(self._config) # this raises errors
|
# ToDo: check health of NeMo endpoints and enable this
|
||||||
|
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||||
|
# await check_health(self._config) # this raises errors
|
||||||
|
|
||||||
provider_model_id = self.get_provider_model_id(model_id)
|
provider_model_id = self.get_provider_model_id(model_id)
|
||||||
request = convert_completion_request(
|
request = convert_completion_request(
|
||||||
|
@ -236,7 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
if tool_prompt_format:
|
if tool_prompt_format:
|
||||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||||
|
|
||||||
await check_health(self._config) # this raises errors
|
# await check_health(self._config) # this raises errors
|
||||||
|
|
||||||
provider_model_id = self.get_provider_model_id(model_id)
|
provider_model_id = self.get_provider_model_id(model_id)
|
||||||
request = await convert_chat_completion_request(
|
request = await convert_chat_completion_request(
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
## ToDo: add supported models list, model validation logic
|
|
|
@ -82,6 +82,9 @@ class NvidiaPostTrainingImpl:
|
||||||
for _ in range(self.config.max_retries):
|
for _ in range(self.config.max_retries):
|
||||||
async with aiohttp.ClientSession(headers=request_headers, timeout=self.timeout) as session:
|
async with aiohttp.ClientSession(headers=request_headers, timeout=self.timeout) as session:
|
||||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||||
|
if response.status >= 400:
|
||||||
|
error_data = await response.json()
|
||||||
|
raise Exception(f"API request failed: {error_data}")
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
@webmethod(route="/post-training/jobs", method="GET")
|
@webmethod(route="/post-training/jobs", method="GET")
|
||||||
|
@ -175,9 +178,9 @@ class NvidiaPostTrainingImpl:
|
||||||
Fine-tunes a model on a dataset.
|
Fine-tunes a model on a dataset.
|
||||||
Currently only supports Lora finetuning for standlone docker container.
|
Currently only supports Lora finetuning for standlone docker container.
|
||||||
Assumptions:
|
Assumptions:
|
||||||
- model is a valid Nvidia model
|
- nemo microservice is running and endpoint is set in config.customizer_url
|
||||||
- dataset is registered separately in nemo datastore
|
- dataset is registered separately in nemo datastore
|
||||||
- model checkpoint is downloaded from ngc and exists in the local directory
|
- model checkpoint is downloaded as per nemo customizer requirements
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
training_config: TrainingConfig - Configuration for training
|
training_config: TrainingConfig - Configuration for training
|
||||||
|
|
59
llama_stack/providers/remote/post_training/nvidia/utils.py
Normal file
59
llama_stack/providers/remote/post_training/nvidia/utils.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .config import NvidiaPostTrainingConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_health(url: str) -> Tuple[bool, bool]:
|
||||||
|
"""
|
||||||
|
Query {url}/v1/health/{live,ready} to check if the server is running and ready
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): URL of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, bool]: (is_live, is_ready)
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
live = await client.get(f"{url}/v1/health/live")
|
||||||
|
ready = await client.get(f"{url}/v1/health/ready")
|
||||||
|
return live.status_code == 200, ready.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
async def check_health(config: NvidiaPostTrainingConfig) -> None:
|
||||||
|
"""
|
||||||
|
Check if the server is running and ready
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): URL of the server
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the server is not running or ready
|
||||||
|
"""
|
||||||
|
if not _is_nvidia_hosted(config):
|
||||||
|
logger.info("Checking NVIDIA NIM health...")
|
||||||
|
try:
|
||||||
|
is_live, is_ready = await _get_health(config.url)
|
||||||
|
if not is_live:
|
||||||
|
raise ConnectionError("NVIDIA NIM is not running")
|
||||||
|
if not is_ready:
|
||||||
|
raise ConnectionError("NVIDIA NIM is not ready")
|
||||||
|
# TODO(mf): should we wait for the server to be ready?
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
|
|
@ -4,11 +4,24 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
||||||
|
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||||
|
TrainingConfig,
|
||||||
|
TrainingConfigDataConfig,
|
||||||
|
TrainingConfigOptimizerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
|
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def post_training_provider_available(llama_stack_client):
|
def post_training_provider_available(llama_stack_client):
|
||||||
providers = llama_stack_client.providers.list()
|
providers = llama_stack_client.providers.list()
|
||||||
|
@ -16,6 +29,7 @@ def post_training_provider_available(llama_stack_client):
|
||||||
return len(post_training_providers) > 0
|
return len(post_training_providers) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
|
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
|
||||||
"""Check if post_training is in the api list.
|
"""Check if post_training is in the api list.
|
||||||
This is a sanity check to ensure the provider is registered."""
|
This is a sanity check to ensure the provider is registered."""
|
||||||
|
@ -24,18 +38,349 @@ def test_post_training_provider_registration(llama_stack_client, post_training_p
|
||||||
|
|
||||||
providers = llama_stack_client.providers.list()
|
providers = llama_stack_client.providers.list()
|
||||||
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
|
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
|
||||||
|
|
||||||
assert len(post_training_providers) > 0
|
assert len(post_training_providers) > 0
|
||||||
|
|
||||||
assert any("post_training" in provider.api for provider in post_training_providers)
|
|
||||||
|
class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
|
os.environ["NVIDIA_BASE_URL"] = "http://nim.test"
|
||||||
|
|
||||||
|
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
|
||||||
|
self.llama_stack_client.initialize = MagicMock(return_value=None)
|
||||||
|
_ = self.llama_stack_client.initialize()
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_supervised_fine_tune(self, mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"config": {
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542657",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.569837",
|
||||||
|
"custom_fields": {},
|
||||||
|
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"model_path": "llama-3_1-8b-instruct",
|
||||||
|
"training_types": [],
|
||||||
|
"finetuning_types": ["lora"],
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_gpus": 4,
|
||||||
|
"num_nodes": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"max_seq_length": 4096,
|
||||||
|
},
|
||||||
|
"dataset": {
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542657",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542660",
|
||||||
|
"custom_fields": {},
|
||||||
|
"name": "default/sample-basic-test",
|
||||||
|
"version_id": "main",
|
||||||
|
"version_tags": [],
|
||||||
|
},
|
||||||
|
"hyperparameters": {
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"training_type": "sft",
|
||||||
|
"batch_size": 16,
|
||||||
|
"epochs": 2,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"lora": {"adapter_dim": 16},
|
||||||
|
},
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
"status": "created",
|
||||||
|
"project": "default",
|
||||||
|
"custom_fields": {},
|
||||||
|
"ownership": {"created_by": "me", "access_policies": {}},
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
|
||||||
|
|
||||||
|
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
||||||
|
|
||||||
|
optimizer_config = TrainingConfigOptimizerConfig(
|
||||||
|
lr=0.0001,
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=2,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
self.llama_stack_client.post_training,
|
||||||
|
"supervised_fine_tune",
|
||||||
|
return_value={
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "created",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"dataset_id": "sample-basic-test",
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid="1234",
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
checkpoint_dir="",
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
training_config=training_config,
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(training_job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
self.assertEqual(training_job["status"], "created")
|
||||||
|
self.assertEqual(training_job["model"], "meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
self.assertEqual(training_job["dataset_id"], "sample-basic-test")
|
||||||
|
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_get_job_status(self, mock_get):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"created_at": "2024-12-09T04:06:28.580220",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"status": "completed",
|
||||||
|
"steps_completed": 1210,
|
||||||
|
"epochs_completed": 2,
|
||||||
|
"percentage_done": 100.0,
|
||||||
|
"best_epoch": 2,
|
||||||
|
"train_loss": 1.718016266822815,
|
||||||
|
"val_loss": 1.8661999702453613,
|
||||||
|
}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
self.llama_stack_client.post_training.job,
|
||||||
|
"status",
|
||||||
|
return_value={
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "completed",
|
||||||
|
"created_at": "2024-12-09T04:06:28.580220",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"steps_completed": 1210,
|
||||||
|
"epochs_completed": 2,
|
||||||
|
"percentage_done": 100.0,
|
||||||
|
"best_epoch": 2,
|
||||||
|
"train_loss": 1.718016266822815,
|
||||||
|
"val_loss": 1.8661999702453613,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
status = self.llama_stack_client.post_training.job.status("cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
|
||||||
|
self.assertEqual(status["status"], "completed")
|
||||||
|
self.assertEqual(status["steps_completed"], 1210)
|
||||||
|
self.assertEqual(status["epochs_completed"], 2)
|
||||||
|
self.assertEqual(status["percentage_done"], 100.0)
|
||||||
|
self.assertEqual(status["best_epoch"], 2)
|
||||||
|
self.assertEqual(status["train_loss"], 1.718016266822815)
|
||||||
|
self.assertEqual(status["val_loss"], 1.8661999702453613)
|
||||||
|
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_get_job(self, mock_get):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"config": {"name": "meta-llama/Llama-3.1-8B-Instruct", "base_model": "meta-llama/Llama-3.1-8B-Instruct"},
|
||||||
|
"dataset": {"name": "default/sample-basic-test"},
|
||||||
|
"hyperparameters": {
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"training_type": "sft",
|
||||||
|
"batch_size": 16,
|
||||||
|
"epochs": 2,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"lora": {"adapter_dim": 16},
|
||||||
|
},
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
"status": "completed",
|
||||||
|
"project": "default",
|
||||||
|
}
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client.post_training,
|
||||||
|
"get_job",
|
||||||
|
return_value={
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "completed",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"dataset_id": "sample-basic-test",
|
||||||
|
"batch_size": 16,
|
||||||
|
"epochs": 2,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adapter_dim": 16,
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
job = client.post_training.get_job("cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
|
||||||
|
self.assertEqual(job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
self.assertEqual(job["status"], "completed")
|
||||||
|
self.assertEqual(job["model"], "meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
self.assertEqual(job["dataset_id"], "sample-basic-test")
|
||||||
|
self.assertEqual(job["batch_size"], 16)
|
||||||
|
self.assertEqual(job["epochs"], 2)
|
||||||
|
self.assertEqual(job["learning_rate"], 0.0001)
|
||||||
|
self.assertEqual(job["adapter_dim"], 16)
|
||||||
|
self.assertEqual(job["output_model"], "default/job-1234")
|
||||||
|
|
||||||
|
@patch("requests.delete")
|
||||||
|
def test_cancel_job(self, mock_delete):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_delete.return_value = mock_response
|
||||||
|
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(client.post_training, "cancel_job", return_value=True):
|
||||||
|
result = client.post_training.cancel_job("cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("aiohttp.ClientSession.post")
|
||||||
|
async def test_async_supervised_fine_tune(self, mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "created",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"dataset_id": "sample-basic-test",
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_post.return_value.__aenter__.return_value = mock_response
|
||||||
|
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
|
||||||
|
|
||||||
|
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
||||||
|
|
||||||
|
optimizer_config = TrainingConfigOptimizerConfig(
|
||||||
|
lr=0.0001,
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = TrainingConfig(
|
||||||
|
n_epochs=2,
|
||||||
|
data_config=data_config,
|
||||||
|
optimizer_config=optimizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client.post_training,
|
||||||
|
"supervised_fine_tune_async",
|
||||||
|
AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "created",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"dataset_id": "sample-basic-test",
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
):
|
||||||
|
training_job = await client.post_training.supervised_fine_tune_async(
|
||||||
|
job_uuid="1234",
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
checkpoint_dir="",
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
training_config=training_config,
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(training_job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
|
||||||
|
self.assertEqual(training_job["status"], "created")
|
||||||
|
self.assertEqual(training_job["model"], "meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
self.assertEqual(training_job["dataset_id"], "sample-basic-test")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("aiohttp.ClientSession.post")
|
||||||
|
async def test_inference_with_fine_tuned_model(self, mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "cmpl-123456",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1677858242,
|
||||||
|
"model": "job-1234",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "The next GTC will take place in the middle of March, 2023.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_post.return_value.__aenter__.return_value = mock_response
|
||||||
|
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client.inference,
|
||||||
|
"completion",
|
||||||
|
AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "cmpl-123456",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1677858242,
|
||||||
|
"model": "job-1234",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "The next GTC will take place in the middle of March, 2023.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await client.inference.completion(
|
||||||
|
content="When is the upcoming GTC event? GTC 2018 attracted over 8,400 attendees. Due to the COVID pandemic of 2020, GTC 2020 was converted to a digital event and drew roughly 59,000 registrants. The 2021 GTC keynote, which was streamed on YouTube on April 12, included a portion that was made with CGI using the Nvidia Omniverse real-time rendering platform. This next GTC will take place in the middle of March, 2023. Answer: ",
|
||||||
|
stream=False,
|
||||||
|
model_id="job-1234",
|
||||||
|
sampling_params={
|
||||||
|
"max_tokens": 128,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response["model"], "job-1234")
|
||||||
|
self.assertEqual(
|
||||||
|
response["choices"][0]["text"], "The next GTC will take place in the middle of March, 2023."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_list_training_jobs(llama_stack_client, post_training_provider_available):
|
if __name__ == "__main__":
|
||||||
"""Check if the list_jobs method returns a list of jobs."""
|
unittest.main()
|
||||||
if not post_training_provider_available:
|
|
||||||
pytest.skip("post training provider not available")
|
|
||||||
|
|
||||||
jobs = llama_stack_client.post_training.job.list()
|
|
||||||
|
|
||||||
assert jobs is not None
|
|
||||||
assert isinstance(jobs, list)
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue