Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 10:42:39 +00:00
add latest code
This commit is contained in:
parent d702296e61
commit 7e2b4489e1
3 changed files with 180 additions and 102 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+## ToDo: add supported models list, model validation logic
@@ -19,11 +19,42 @@ class NvidiaPostTrainingConfig(BaseModel):
     )
 
     user_id: Optional[str] = Field(
-        default_factory=lambda: os.getenv("NVIDIA_USER_ID"),
+        default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
         description="The NVIDIA user ID, only needed of using the hosted service",
     )
 
+    dataset_namespace: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
+        description="The NVIDIA dataset namespace, only needed of using the hosted service",
+    )
+
+    access_policies: Optional[dict] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {}),
+        description="The NVIDIA access policies, only needed of using the hosted service",
+    )
+
+    project_id: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
+        description="The NVIDIA project ID, only needed of using the hosted service",
+    )
+
+    # ToDO: validate this, add default value
     customizer_url: str = Field(
         default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
         description="Base URL for the NeMo Customizer API",
     )
+
+    timeout: int = Field(
+        default=300,
+        description="Timeout for the NVIDIA Post Training API",
+    )
+
+    max_retries: int = Field(
+        default=3,
+        description="Maximum number of retries for the NVIDIA Post Training API",
+    )
+
+    output_model_dir: str = Field(
+        default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
+        description="Directory to save the output model",
+    )
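For orientation, a minimal usage sketch (not part of the diff) showing how the new environment-driven defaults resolve, assuming standard pydantic default_factory behavior; the customizer endpoint below is a placeholder:

    import os

    from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig

    # Placeholder endpoint; customizer_url has no hard-coded default and is read from this env var.
    os.environ.setdefault("NVIDIA_CUSTOMIZER_URL", "http://localhost:9000")

    config = NvidiaPostTrainingConfig()
    print(config.user_id)            # "llama-stack-user" unless NVIDIA_USER_ID is set
    print(config.dataset_namespace)  # "default" unless NVIDIA_DATASET_NAMESPACE is set
    print(config.timeout, config.max_retries)  # 300, 3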
@@ -4,16 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from datetime import datetime
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 import aiohttp
-from aiohttp import ClientTimeout
+from pydantic import BaseModel, ConfigDict
 
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
     JobStatus,
-    ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
@@ -24,6 +23,32 @@ from llama_stack.providers.remote.post_training.nvidia.config import (
 )
 from llama_stack.schema_utils import webmethod
 
+# Map API status to JobStatus enum
+STATUS_MAPPING = {
+    "running": "in_progress",
+    "completed": "completed",
+    "failed": "failed",
+    "cancelled": "cancelled",
+    "pending": "scheduled",
+}
+
+
+class NvidiaPostTrainingJob(PostTrainingJob):
+    """Parse the response from the Customizer API.
+
+    Inherits job_uuid from PostTrainingJob.
+    Adds status, created_at, updated_at parameters.
+    Passes through all other parameters from data field in the response.
+    """
+
+    model_config = ConfigDict(extra="allow")
+    status: JobStatus
+    created_at: datetime
+    updated_at: datetime
+
+
+class ListNvidiaPostTrainingJobs(BaseModel):
+    data: List[NvidiaPostTrainingJob]
+
 
 class NvidiaPostTrainingImpl:
     def __init__(self, config: NvidiaPostTrainingConfig):
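As an illustration (not part of the diff), a sketch of how a Customizer-style job record maps onto the new model; the record shape is hypothetical, and NvidiaPostTrainingJob / STATUS_MAPPING are the names defined above:

    from datetime import datetime

    from llama_stack.apis.post_training import JobStatus

    # Hypothetical Customizer job record; extra keys pass through via extra="allow".
    record = {
        "id": "cust-123",
        "status": "RUNNING",
        "created_at": "2025-03-01T12:00:00",
        "updated_at": "2025-03-01T12:30:00",
        "output_model": "test-example-model@v1",
    }

    job = NvidiaPostTrainingJob(
        job_uuid=record.pop("id"),
        status=JobStatus(STATUS_MAPPING.get(record.pop("status").lower(), "unknown")),
        created_at=datetime.fromisoformat(record.pop("created_at")),
        updated_at=datetime.fromisoformat(record.pop("updated_at")),
        **record,
    )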
@@ -32,96 +57,110 @@ class NvidiaPostTrainingImpl:
         if config.api_key:
             self.headers["Authorization"] = f"Bearer {config.api_key}"
 
-        self.timeout = ClientTimeout(total=config.timeout)
+        self.timeout = aiohttp.ClientTimeout(total=config.timeout)
 
-    async def _make_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
+    async def _make_request(
+        self,
+        method: str,
+        path: str,
+        headers: Optional[Dict[str, Any]] = None,
+        params: Optional[Dict[str, Any]] = None,
+        json: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Helper method to make HTTP requests to the Customizer API."""
         url = f"{self.config.customizer_url}{path}"
+        request_headers = self.headers.copy()  # Create a copy to avoid modifying the original
 
-        for attempt in range(self.config.max_retries):
-            async with aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) as session:
-                async with session.request(method, url, **kwargs) as response:
-                    if response.status >= 400:
-                        error_data = response
-                        raise Exception(f"API request failed: {error_data}")
+        if headers:
+            request_headers.update(headers)
+
+        # Add content-type header for JSON requests
+        if json and "Content-Type" not in request_headers:
+            request_headers["Content-Type"] = "application/json"
+
+        for _ in range(self.config.max_retries):
+            async with aiohttp.ClientSession(headers=request_headers, timeout=self.timeout) as session:
+                async with session.request(method, url, params=params, json=json, **kwargs) as response:
                     return await response.json()
 
     @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(
         self,
-        page: int = 1,
-        page_size: int = 10,
-        sort: Literal[
-            "created_at",
-            "-created_at",
-        ] = "created_at",
-    ) -> ListPostTrainingJobsResponse:
-        """
-        Get all customization jobs.
-        """
+        page: Optional[int] = 1,
+        page_size: Optional[int] = 10,
+        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
+    ) -> ListNvidiaPostTrainingJobs:
+        """Get all customization jobs.
+        Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
+        """
         params = {"page": page, "page_size": page_size, "sort": sort}
 
-        response = await self._make_request(
-            "GET",
-            "/v1/customization/jobs",
-            # params=params
-        )
+        response = await self._make_request("GET", "/v1/customization/jobs", params=params)
 
-        # Convert customization jobs to PostTrainingJob objects
-        jobs = [PostTrainingJob(job_uuid=job["id"]) for job in response["data"]]
-
-        # Remove the data and pass through other fields
-        response.pop("data")
-        return ListPostTrainingJobsResponse(data=jobs, **response)
+        jobs = []
+        for job in response.get("data", []):
+            job_id = job.pop("id")
+            job_status = job.pop("status", "unknown").lower()
+            mapped_status = STATUS_MAPPING.get(job_status, "unknown")
+
+            # Convert string timestamps to datetime objects
+            created_at = datetime.fromisoformat(job.pop("created_at")) if "created_at" in job else datetime.now()
+            updated_at = datetime.fromisoformat(job.pop("updated_at")) if "updated_at" in job else datetime.now()
+
+            # Create NvidiaPostTrainingJob instance
+            jobs.append(
+                NvidiaPostTrainingJob(
+                    job_uuid=job_id,
+                    status=JobStatus(mapped_status),
+                    created_at=created_at,
+                    updated_at=updated_at,
+                    **job,
+                )
+            )
+
+        return ListNvidiaPostTrainingJobs(data=jobs)
 
     @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        """
-        Get the status of a customization job.
-        """
+    async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
+        """Get the status of a customization job.
+        Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
+        """
         response = await self._make_request(
             "GET",
             f"/v1/customization/jobs/{job_uuid}/status",
-            params=job_uuid,
+            params={"job_id": job_uuid},
         )
 
-        # Map API status to JobStatus enum
-        status_mapping = {
-            "running": "in_progress",
-            "completed": "completed",
-            "failed": "failed",
-            # "cancelled": "cancelled",
-            "pending": "scheduled",
-        }
-
-        api_status = response["status"].lower()
-        mapped_status = status_mapping.get(api_status, "unknown")
-
-        # todo: add callback for rest of the parameters
-        response["status"] = JobStatus(mapped_status)
-        response["job_uuid"] = job_uuid
-        response["started_at"] = datetime.fromisoformat(response["created_at"])
-
-        return PostTrainingJobStatusResponse(**response)
+        api_status = response.pop("status").lower()
+        mapped_status = STATUS_MAPPING.get(api_status, "unknown")
+
+        return NvidiaPostTrainingJob(
+            status=JobStatus(mapped_status),
+            job_uuid=job_uuid,
+            created_at=datetime.fromisoformat(response.pop("created_at")),
+            updated_at=datetime.fromisoformat(response.pop("updated_at")),
+            **response,
+        )
 
     @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None:
-        """
-        Cancels a customization job.
-        """
-        response = await self._make_request(
-            "POST", f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+        """Cancels a customization job."""
+        await self._make_request(
+            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
         )
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+        """Get artifacts for a specific training job."""
         raise NotImplementedError("Job artifacts are not implemented yet")
 
-    ## post-training artifacts
     @webmethod(route="/post-training/artifacts", method="GET")
     async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+        """Get all post-training artifacts."""
         raise NotImplementedError("Job artifacts are not implemented yet")
 
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
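A small end-to-end sketch (not part of the diff) for the updated listing path; the import path for NvidiaPostTrainingImpl and the endpoint value are assumptions:

    import asyncio

    from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
    # Module path assumed; NvidiaPostTrainingImpl is the class shown in this diff.
    from llama_stack.providers.remote.post_training.nvidia.post_training import NvidiaPostTrainingImpl

    async def main() -> None:
        config = NvidiaPostTrainingConfig(customizer_url="http://localhost:9000")  # placeholder endpoint
        impl = NvidiaPostTrainingImpl(config)

        # Returns ListNvidiaPostTrainingJobs; each entry carries status, created_at, updated_at.
        jobs = await impl.get_training_jobs(page=1, page_size=10, sort="created_at")
        for job in jobs.data:
            print(job.job_uuid, job.status)

    asyncio.run(main())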
@@ -140,39 +179,31 @@ class NvidiaPostTrainingImpl:
         - dataset is registered separately in nemo datastore
         - model checkpoint is downloaded from ngc and exists in the local directory
 
         Parameters:
-            training_config: TrainingConfig
-            model: str
-            algorithm_config: Optional[AlgorithmConfig]
-            dataset: Optional[str]
-            run_local_jobs: bool = False True for standalone mode.
-            nemo_data_store_url: str URL of NeMo Data Store for Customizer to connect to for dataset and model files.
-
-        LoRA config:
-            training_type = sft
-            finetuning_type = lora
-            adapter_dim = Size of adapter layers added throughout the model.
-            adapter_dropout = Dropout probability in the adapter layer.
-
-        ToDo:
-            support for model config of helm chart ??
-            /status endpoint for model customization
-            Get Metrics for customization
-            Weights and Biases integration ??
-            OpenTelemetry integration ??
+            training_config: TrainingConfig - Configuration for training
+            model: str - Model identifier
+            algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
+            checkpoint_dir: Optional[str] - Directory containing model checkpoints
+            job_uuid: str - Unique identifier for the job
+            hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
+            logger_config: Dict[str, Any] - Configuration for logging
         """
         # map model to nvidia model name
         model_mapping = {
             "Llama3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
+            "meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
         }
         nvidia_model = model_mapping.get(model, model)
 
-        # Prepare the customization job request
+        # Get output model directory from config
+        output_model = self.config.output_model_dir
+
+        # Prepare base job configuration
         job_config = {
             "config": nvidia_model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
-                "namespace": "default",  # todo: could be configurable in the future
+                "namespace": self.config.dataset_namespace,
             },
             "hyperparameters": {
                 "training_type": "sft",
@@ -180,33 +211,34 @@ class NvidiaPostTrainingImpl:
                 "epochs": training_config["n_epochs"],
                 "batch_size": training_config["data_config"]["batch_size"],
                 "learning_rate": training_config["optimizer_config"]["lr"],
-                "lora": {"adapter_dim": 16},
             },
-            "project": "llama-stack-project",  # todo: could be configurable
-            "ownership": {
-                "created_by": self.config.user_id or "llama-stack-user",
-            },
-            "output_model": f"llama-stack-{training_config['data_config']['dataset_id']}-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
+            "project": self.config.project_id,
+            "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
+            "output_model": output_model,
         }
 
-        # Add LoRA specific configuration if provided
-        if isinstance(algorithm_config, Dict) and algorithm_config["type"] == "LoRA":
-            if algorithm_config["adapter_dim"]:
-                job_config["hyperparameters"]["lora"]["adapter_dim"] = algorithm_config["adapter_dim"]
-        else:
-            raise NotImplementedError(f"Algorithm config {type(algorithm_config)} not implemented.")
-
-        # Make the API request to create the customization job
-        response = await self._make_request("POST", "/v1/customization/jobs", json=job_config)
+        # Handle LoRA-specific configuration
+        if algorithm_config:
+            if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
+                # Extract LoRA-specific parameters
+                lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
+                job_config["hyperparameters"]["lora"] = lora_config
+
+                # Add adapter_dim if available in training_config
+                if training_config.get("algorithm_config", {}).get("adapter_dim"):
+                    job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"][
+                        "adapter_dim"
+                    ]
+
+        # Create the customization job
+        response = await self._make_request(
+            method="POST",
+            path="/v1/customization/jobs",
+            headers={"Accept": "application/json"},
+            json=job_config,
+        )
 
-        # Parse the response to extract relevant fields
         job_uuid = response["id"]
-        created_at = response["created_at"]
-        status = response["status"]
-        output_model = response["output_model"]
-        project = response["project"]
-        created_by = response["ownership"]["created_by"]
 
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
@@ -217,7 +249,9 @@ class NvidiaPostTrainingImpl:
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Optimize a model based on preference data."""
+        raise NotImplementedError("Preference optimization is not implemented yet")
 
     @webmethod(route="/post-training/job/logs", method="GET")
     async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
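Finally, a hedged sketch (not part of the diff) of a supervised fine-tune call; the keyword names are taken from the docstring above, the values are placeholders, and only the keys the new job-construction code actually reads from training_config are shown:

    # Inside an async context, with impl constructed as in the earlier sketch.
    job = await impl.supervised_fine_tune(
        job_uuid="job-1234",
        model="meta-llama/Llama-3.1-8B-Instruct",  # mapped to "meta/llama-3.1-8b-instruct"
        checkpoint_dir=None,
        algorithm_config={"type": "LoRA", "adapter_dim": 16, "adapter_dropout": 0.1},
        training_config={
            "n_epochs": 1,
            "data_config": {"dataset_id": "my-dataset", "batch_size": 8},
            "optimizer_config": {"lr": 1e-4},
        },
        hyperparam_search_config={},
        logger_config={},
    )
    print(job.job_uuid)  # PostTrainingJob carrying the Customizer job id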