fix changes post merge

This commit is contained in:
raspawar 2025-03-21 18:09:17 +05:30
parent e95b1e9739
commit e4b39aacb8
5 changed files with 65 additions and 20 deletions

View file

@ -14,7 +14,7 @@ from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
PostTraining,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
@ -53,7 +53,11 @@ class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
self.config = config
self.headers = {}
@ -146,7 +150,7 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
@ -175,10 +179,10 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJob(
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
created_at=datetime.fromisoformat(response.pop("created_at")),
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
@ -188,10 +192,10 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
)
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def supervised_fine_tune(
@ -389,14 +393,14 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
if isinstance(algorithm_config, LoraFinetuningConfig) and algorithm_config.type == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
"adapter_dim": getattr(algorithm_config, "adapter_dim", None),
"alpha": getattr(algorithm_config, "alpha", None),
"adapter_dropout": getattr(algorithm_config, "adapter_dropout", None),
}.items()
if v is not None
}
@ -432,5 +436,5 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")