ignore ownership param

commit d7340da7a6
parent c885015e6f
Author: Ubuntu
Date:   2025-03-18 08:50:13 +00:00

2 changed files with 11 additions and 24 deletions

llama_stack/providers/remote/post_training/nvidia/config.py

@@ -18,21 +18,11 @@ class NvidiaPostTrainingConfig(BaseModel):
         description="The NVIDIA API key.",
     )
-    user_id: Optional[str] = Field(
-        default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
-        description="The NVIDIA user ID.",
-    )
     dataset_namespace: Optional[str] = Field(
         default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
         description="The NVIDIA dataset namespace.",
     )
-    access_policies: Optional[dict] = Field(
-        default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {"arbitrary": "json"}),
-        description="The NVIDIA access policies.",
-    )
     project_id: Optional[str] = Field(
         default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
         description="The NVIDIA project ID.",
@@ -64,7 +54,6 @@ class NvidiaPostTrainingConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "api_key": "${env.NVIDIA_API_KEY:}",
-            "user_id": "${env.NVIDIA_USER_ID:llama-stack-user}",
             "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
             "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
             "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",

llama_stack/providers/remote/post_training/nvidia/post_training.py

@@ -14,6 +14,7 @@ from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
     JobStatus,
+    PostTraining,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
@@ -23,7 +24,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import (
     NvidiaPostTrainingConfig,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.schema_utils import webmethod
 from .models import _MODEL_ENTRIES
@@ -54,7 +54,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
     data: List[NvidiaPostTrainingJob]
-class NvidiaPostTrainingAdapter(ModelRegistryHelper):
+class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
     def __init__(self, config: NvidiaPostTrainingConfig):
         self.config = config
         self.headers = {}
@@ -98,7 +98,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
                 raise Exception(f"API request failed: {error_data}")
             return await response.json()
-    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(
         self,
         page: Optional[int] = 1,
@@ -135,7 +134,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         return ListNvidiaPostTrainingJobs(data=jobs)
-    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
         """Get the status of a customization job.
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
@@ -157,24 +155,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             **response,
         )
-    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None:
         """Cancels a customization job."""
         await self._make_request(
             method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
         )
-    @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
         """Get artifacts for a specific training job."""
         raise NotImplementedError("Job artifacts are not implemented yet")
-    @webmethod(route="/post-training/artifacts", method="GET")
     async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
         """Get all post-training artifacts."""
         raise NotImplementedError("Job artifacts are not implemented yet")
-    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -184,6 +178,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig] = None,
+        extra_json: Optional[Dict[str, Any]] = None,
+        params: Optional[Dict[str, Any]] = None,
+        headers: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.
@@ -204,8 +202,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         Environment Variables:
         - NVIDIA_PROJECT_ID: ID of the project
-        - NVIDIA_USER_ID: ID of the user
-        - NVIDIA_ACCESS_POLICIES: Access policies for the project
         - NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
         - NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
@@ -241,6 +237,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         # map model to nvidia model name
         nvidia_model = self.get_provider_model_id(model)
+        # Check the extra parameters
+        print(hyperparam_search_config, extra_json, params, headers, kwargs)
         # Check for unsupported parameters
         if checkpoint_dir or hyperparam_search_config or logger_config:
             warnings.warn(
@@ -280,7 +279,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
                 "learning_rate": training_config["optimizer_config"].get("lr", 0.0001),
             },
             "project": self.config.project_id,
-            "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
+            # TODO: ignored ownership, add it later
+            # "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
             "output_model": output_model,
         }
@@ -296,7 +296,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config")
         else:
             raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
         # Create the customization job
         response = await self._make_request(
             method="POST",
@@ -326,7 +325,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
-    @webmethod(route="/post-training/job/logs", method="GET")
     async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
         """Get the container logs of a customization job."""
         raise NotImplementedError("Job logs are not implemented yet")