add latest code

Ubuntu 2025-03-06 16:33:51 +00:00 committed by raspawar
parent d702296e61
commit 7e2b4489e1
3 changed files with 180 additions and 102 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+## ToDo: add supported models list, model validation logic
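The ToDo above is the whole of the new module so far. A minimal sketch of what a supported-models allowlist plus validation could look like, assuming a plain set keyed on NeMo Customizer model identifiers; the only name grounded in this commit is "meta/llama-3.1-8b-instruct", and the helper itself is hypothetical:

    # Hypothetical allowlist and validator; not part of this commit.
    SUPPORTED_MODELS = {
        "meta/llama-3.1-8b-instruct",  # the one model this commit maps to
    }

    def validate_model(model: str) -> None:
        """Reject model identifiers the Customizer backend does not support."""
        if model not in SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model}")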

@@ -19,11 +19,42 @@ class NvidiaPostTrainingConfig(BaseModel):
     )
     user_id: Optional[str] = Field(
-        default_factory=lambda: os.getenv("NVIDIA_USER_ID"),
+        default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
         description="The NVIDIA user ID, only needed if using the hosted service",
     )
+    dataset_namespace: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
+        description="The NVIDIA dataset namespace, only needed if using the hosted service",
+    )
+    access_policies: Optional[dict] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {}),
+        description="The NVIDIA access policies, only needed if using the hosted service",
+    )
+    project_id: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
+        description="The NVIDIA project ID, only needed if using the hosted service",
+    )
+    # ToDo: validate this, add default value
     customizer_url: str = Field(
         default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
         description="Base URL for the NeMo Customizer API",
     )
+    timeout: int = Field(
+        default=300,
+        description="Timeout for the NVIDIA Post Training API",
+    )
+    max_retries: int = Field(
+        default=3,
+        description="Maximum number of retries for the NVIDIA Post Training API",
+    )
+    output_model_dir: str = Field(
+        default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
+        description="Directory to save the output model",
+    )
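All of the new fields resolve their defaults through default_factory at instantiation time, so the provider can be configured entirely from the environment. A minimal usage sketch; the URL and namespace values are placeholders:

    import os

    # Placeholder values; a real deployment injects these before startup.
    os.environ["NVIDIA_CUSTOMIZER_URL"] = "https://customizer.example.com"
    os.environ["NVIDIA_DATASET_NAMESPACE"] = "my-namespace"

    config = NvidiaPostTrainingConfig()
    assert config.dataset_namespace == "my-namespace"
    assert config.project_id == "test-project"  # unset env var falls back to the default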

@@ -4,16 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from datetime import datetime
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional

 import aiohttp
-from aiohttp import ClientTimeout
+from pydantic import BaseModel, ConfigDict

 from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
     JobStatus,
-    ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
@@ -24,6 +23,32 @@ from llama_stack.providers.remote.post_training.nvidia.config import (
 )
 from llama_stack.schema_utils import webmethod

+# Map API status to JobStatus enum
+STATUS_MAPPING = {
+    "running": "in_progress",
+    "completed": "completed",
+    "failed": "failed",
+    "cancelled": "cancelled",
+    "pending": "scheduled",
+}
+
+
+class NvidiaPostTrainingJob(PostTrainingJob):
+    """Parse the response from the Customizer API.
+    Inherits job_uuid from PostTrainingJob.
+    Adds status, created_at, updated_at parameters.
+    Passes through all other parameters from data field in the response.
+    """
+
+    model_config = ConfigDict(extra="allow")
+    status: JobStatus
+    created_at: datetime
+    updated_at: datetime
+
+
+class ListNvidiaPostTrainingJobs(BaseModel):
+    data: List[NvidiaPostTrainingJob]
+
+
 class NvidiaPostTrainingImpl:
     def __init__(self, config: NvidiaPostTrainingConfig):
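Because NvidiaPostTrainingJob declares model_config = ConfigDict(extra="allow"), any extra fields in a Customizer response survive as attributes instead of failing validation. A small sketch with an invented response fragment:

    from datetime import datetime

    # Invented payload; only id/status/created_at mirror fields used in this diff.
    raw = {
        "id": "cust-job-123",
        "status": "running",
        "created_at": "2025-03-06T16:33:51",
        "output_model": "test-example-model@v1",
    }

    job = NvidiaPostTrainingJob(
        job_uuid=raw["id"],
        status=JobStatus(STATUS_MAPPING[raw["status"]]),  # "running" -> "in_progress"
        created_at=datetime.fromisoformat(raw["created_at"]),
        updated_at=datetime.now(),
        output_model=raw["output_model"],  # kept via extra="allow"
    )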
@@ -32,96 +57,110 @@ class NvidiaPostTrainingImpl:
         if config.api_key:
             self.headers["Authorization"] = f"Bearer {config.api_key}"

-        self.timeout = ClientTimeout(total=config.timeout)
+        self.timeout = aiohttp.ClientTimeout(total=config.timeout)

-    async def _make_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
+    async def _make_request(
+        self,
+        method: str,
+        path: str,
+        headers: Optional[Dict[str, Any]] = None,
+        params: Optional[Dict[str, Any]] = None,
+        json: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Helper method to make HTTP requests to the Customizer API."""
         url = f"{self.config.customizer_url}{path}"
+        request_headers = self.headers.copy()  # Create a copy to avoid modifying the original

-        for attempt in range(self.config.max_retries):
-            async with aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) as session:
-                async with session.request(method, url, **kwargs) as response:
-                    if response.status >= 400:
-                        error_data = response
-                        raise Exception(f"API request failed: {error_data}")
+        if headers:
+            request_headers.update(headers)
+
+        # Add content-type header for JSON requests
+        if json and "Content-Type" not in request_headers:
+            request_headers["Content-Type"] = "application/json"
+
+        for _ in range(self.config.max_retries):
+            async with aiohttp.ClientSession(headers=request_headers, timeout=self.timeout) as session:
+                async with session.request(method, url, params=params, json=json, **kwargs) as response:
                     return await response.json()

     @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(
         self,
-        page: int = 1,
-        page_size: int = 10,
-        sort: Literal[
-            "created_at",
-            "-created_at",
-        ] = "created_at",
-    ) -> ListPostTrainingJobsResponse:
-        """
-        Get all customization jobs.
+        page: Optional[int] = 1,
+        page_size: Optional[int] = 10,
+        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
+    ) -> ListNvidiaPostTrainingJobs:
+        """Get all customization jobs.
+        Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
         """
         params = {"page": page, "page_size": page_size, "sort": sort}

-        response = await self._make_request(
-            "GET",
-            "/v1/customization/jobs",
-            # params=params
-        )
+        response = await self._make_request("GET", "/v1/customization/jobs", params=params)

-        # Convert customization jobs to PostTrainingJob objects
-        jobs = [PostTrainingJob(job_uuid=job["id"]) for job in response["data"]]
+        jobs = []
+        for job in response.get("data", []):
+            job_id = job.pop("id")
+            job_status = job.pop("status", "unknown").lower()
+            mapped_status = STATUS_MAPPING.get(job_status, "unknown")

-        # Remove the data and pass through other fields
-        response.pop("data")
-        return ListPostTrainingJobsResponse(data=jobs, **response)
+            # Convert string timestamps to datetime objects
+            created_at = datetime.fromisoformat(job.pop("created_at")) if "created_at" in job else datetime.now()
+            updated_at = datetime.fromisoformat(job.pop("updated_at")) if "updated_at" in job else datetime.now()
+
+            # Create NvidiaPostTrainingJob instance
+            jobs.append(
+                NvidiaPostTrainingJob(
+                    job_uuid=job_id,
+                    status=JobStatus(mapped_status),
+                    created_at=created_at,
+                    updated_at=updated_at,
+                    **job,
+                )
+            )
+
+        return ListNvidiaPostTrainingJobs(data=jobs)

     @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        """
-        Get the status of a customization job.
+    async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
+        """Get the status of a customization job.
+        Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
         """
         response = await self._make_request(
             "GET",
             f"/v1/customization/jobs/{job_uuid}/status",
-            params=job_uuid,
+            params={"job_id": job_uuid},
         )

-        # Map API status to JobStatus enum
-        status_mapping = {
-            "running": "in_progress",
-            "completed": "completed",
-            "failed": "failed",
-            # "cancelled": "cancelled",
-            "pending": "scheduled",
-        }
-        api_status = response["status"].lower()
-        mapped_status = status_mapping.get(api_status, "unknown")
-
-        # todo: add callback for rest of the parameters
-        response["status"] = JobStatus(mapped_status)
-        response["job_uuid"] = job_uuid
-        response["started_at"] = datetime.fromisoformat(response["created_at"])
-        return PostTrainingJobStatusResponse(**response)
+        api_status = response.pop("status").lower()
+        mapped_status = STATUS_MAPPING.get(api_status, "unknown")
+
+        return NvidiaPostTrainingJob(
+            status=JobStatus(mapped_status),
+            job_uuid=job_uuid,
+            created_at=datetime.fromisoformat(response.pop("created_at")),
+            updated_at=datetime.fromisoformat(response.pop("updated_at")),
+            **response,
+        )

     @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None:
-        """
-        Cancels a customization job.
-        """
-        response = await self._make_request(
-            "POST", f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+        """Cancels a customization job."""
+        await self._make_request(
+            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
         )

     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+        """Get artifacts for a specific training job."""
         raise NotImplementedError("Job artifacts are not implemented yet")

-    ## post-training artifacts
     @webmethod(route="/post-training/artifacts", method="GET")
     async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+        """Get all post-training artifacts."""
         raise NotImplementedError("Job artifacts are not implemented yet")

+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
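Together, the methods reworked in this hunk form a small async client surface. A polling sketch built only on what the diff defines; the construction of impl and the sleep interval are assumptions:

    import asyncio

    async def wait_for_job(impl: NvidiaPostTrainingImpl, job_uuid: str) -> NvidiaPostTrainingJob:
        # Poll until the job leaves the not-yet-finished states from STATUS_MAPPING.
        while True:
            job = await impl.get_training_job_status(job_uuid)
            if job.status not in (JobStatus.scheduled, JobStatus.in_progress):
                return job
            await asyncio.sleep(30)  # arbitrary polling interval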
@@ -140,39 +179,31 @@ class NvidiaPostTrainingImpl:
         - dataset is registered separately in nemo datastore
         - model checkpoint is downloaded from ngc and exists in the local directory

         Parameters:
-            training_config: TrainingConfig
-            model: str
-            algorithm_config: Optional[AlgorithmConfig]
-            dataset: Optional[str]
-            run_local_jobs: bool = False  True for standalone mode.
-            nemo_data_store_url: str  URL of NeMo Data Store for Customizer to connect to for dataset and model files.
-
-        LoRA config:
-            training_type = sft
-            finetuning_type = lora
-            adapter_dim = Size of adapter layers added throughout the model.
-            adapter_dropout = Dropout probability in the adapter layer.
-
-        ToDo:
-            support for model config of helm chart ??
-            /status endpoint for model customization
-            Get Metrics for customization
-            Weights and Biases integration ??
-            OpenTelemetry integration ??
+            training_config: TrainingConfig - Configuration for training
+            model: str - Model identifier
+            algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
+            checkpoint_dir: Optional[str] - Directory containing model checkpoints
+            job_uuid: str - Unique identifier for the job
+            hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
+            logger_config: Dict[str, Any] - Configuration for logging
         """
         # map model to nvidia model name
         model_mapping = {
             "Llama3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
+            "meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
         }
         nvidia_model = model_mapping.get(model, model)

-        # Prepare the customization job request
+        # Get output model directory from config
+        output_model = self.config.output_model_dir
+
+        # Prepare base job configuration
         job_config = {
             "config": nvidia_model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
-                "namespace": "default",  # todo: could be configurable in the future
+                "namespace": self.config.dataset_namespace,
             },
             "hyperparameters": {
                 "training_type": "sft",
@@ -180,33 +211,34 @@ class NvidiaPostTrainingImpl:
                 "epochs": training_config["n_epochs"],
                 "batch_size": training_config["data_config"]["batch_size"],
                 "learning_rate": training_config["optimizer_config"]["lr"],
-                "lora": {"adapter_dim": 16},
             },
-            "project": "llama-stack-project",  # todo: could be configurable
-            "ownership": {
-                "created_by": self.config.user_id or "llama-stack-user",
-            },
-            "output_model": f"llama-stack-{training_config['data_config']['dataset_id']}-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
+            "project": self.config.project_id,
+            "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
+            "output_model": output_model,
         }

-        # Add LoRA specific configuration if provided
-        if isinstance(algorithm_config, Dict) and algorithm_config["type"] == "LoRA":
-            if algorithm_config["adapter_dim"]:
-                job_config["hyperparameters"]["lora"]["adapter_dim"] = algorithm_config["adapter_dim"]
-        else:
-            raise NotImplementedError(f"Algorithm config {type(algorithm_config)} not implemented.")
+        # Handle LoRA-specific configuration
+        if algorithm_config:
+            if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
+                # Extract LoRA-specific parameters
+                lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
+                job_config["hyperparameters"]["lora"] = lora_config
+
+                # Add adapter_dim if available in training_config
+                if training_config.get("algorithm_config", {}).get("adapter_dim"):
+                    job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"][
+                        "adapter_dim"
+                    ]

-        # Make the API request to create the customization job
-        response = await self._make_request("POST", "/v1/customization/jobs", json=job_config)
+        # Create the customization job
+        response = await self._make_request(
+            method="POST",
+            path="/v1/customization/jobs",
+            headers={"Accept": "application/json"},
+            json=job_config,
+        )

-        # Parse the response to extract relevant fields
         job_uuid = response["id"]
-        created_at = response["created_at"]
-        status = response["status"]
-        output_model = response["output_model"]
-        project = response["project"]
-        created_by = response["ownership"]["created_by"]

         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
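A call sketch for the LoRA path in the hunk above. The keyword names follow the parameter list documented in the supervised_fine_tune docstring; the dataset id, adapter settings, and training values are placeholders:

    async def run_lora_sft(impl: NvidiaPostTrainingImpl) -> PostTrainingJob:
        return await impl.supervised_fine_tune(
            job_uuid="",  # the Customizer service assigns the real id
            model="meta-llama/Llama-3.1-8B-Instruct",  # normalized by model_mapping
            algorithm_config={"type": "LoRA", "adapter_dim": 16, "adapter_dropout": 0.1},
            training_config={
                "n_epochs": 2,
                "data_config": {"dataset_id": "sample-dataset-id", "batch_size": 8},
                "optimizer_config": {"lr": 1e-4},
            },
            checkpoint_dir=None,
            hyperparam_search_config={},
            logger_config={},
        )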
@@ -217,7 +249,9 @@ class NvidiaPostTrainingImpl:
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Optimize a model based on preference data."""
+        raise NotImplementedError("Preference optimization is not implemented yet")

     @webmethod(route="/post-training/job/logs", method="GET")
     async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: