mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
ignore ownership param
This commit is contained in:
parent
c885015e6f
commit
d7340da7a6
2 changed files with 11 additions and 24 deletions
|
@ -18,21 +18,11 @@ class NvidiaPostTrainingConfig(BaseModel):
|
|||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
user_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
|
||||
description="The NVIDIA user ID.",
|
||||
)
|
||||
|
||||
dataset_namespace: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
|
||||
description="The NVIDIA dataset namespace.",
|
||||
)
|
||||
|
||||
access_policies: Optional[dict] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {"arbitrary": "json"}),
|
||||
description="The NVIDIA access policies.",
|
||||
)
|
||||
|
||||
project_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
|
||||
description="The NVIDIA project ID.",
|
||||
|
@ -64,7 +54,6 @@ class NvidiaPostTrainingConfig(BaseModel):
|
|||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||
"user_id": "${env.NVIDIA_USER_ID:llama-stack-user}",
|
||||
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
|
||||
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
|
||||
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.post_training import (
|
|||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
PostTraining,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
|
@ -23,7 +24,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import (
|
|||
NvidiaPostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
from .models import _MODEL_ENTRIES
|
||||
|
||||
|
@ -54,7 +54,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
|
|||
data: List[NvidiaPostTrainingJob]
|
||||
|
||||
|
||||
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||
class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
|
||||
def __init__(self, config: NvidiaPostTrainingConfig):
|
||||
self.config = config
|
||||
self.headers = {}
|
||||
|
@ -98,7 +98,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(
|
||||
self,
|
||||
page: Optional[int] = 1,
|
||||
|
@ -135,7 +134,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
return ListNvidiaPostTrainingJobs(data=jobs)
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
|
||||
"""Get the status of a customization job.
|
||||
Note: the return type differs from the base class; this method returns NvidiaPostTrainingJob instead of PostTrainingJobResponse.
|
||||
|
@ -157,24 +155,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
**response,
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancels a customization job."""
|
||||
await self._make_request(
|
||||
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
"""Get artifacts for a specific training job."""
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
@webmethod(route="/post-training/artifacts", method="GET")
|
||||
async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
"""Get all post-training artifacts."""
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
@ -184,6 +178,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
extra_json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> NvidiaPostTrainingJob:
|
||||
"""
|
||||
Fine-tunes a model on a dataset.
|
||||
|
@ -204,8 +202,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
Environment Variables:
|
||||
- NVIDIA_PROJECT_ID: ID of the project
|
||||
- NVIDIA_USER_ID: ID of the user
|
||||
- NVIDIA_ACCESS_POLICIES: Access policies for the project
|
||||
- NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
|
||||
- NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
|
||||
|
||||
|
@ -241,6 +237,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
# map model to nvidia model name
|
||||
nvidia_model = self.get_provider_model_id(model)
|
||||
|
||||
# Print the extra parameters that were passed in (TODO: replace print with logging or remove)
|
||||
print(hyperparam_search_config, extra_json, params, headers, kwargs)
|
||||
|
||||
# Check for unsupported parameters
|
||||
if checkpoint_dir or hyperparam_search_config or logger_config:
|
||||
warnings.warn(
|
||||
|
@ -280,7 +279,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
"learning_rate": training_config["optimizer_config"].get("lr", 0.0001),
|
||||
},
|
||||
"project": self.config.project_id,
|
||||
"ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
|
||||
# TODO: ownership is not sent for now; add it back later
|
||||
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
|
||||
"output_model": output_model,
|
||||
}
|
||||
|
||||
|
@ -296,7 +296,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
# Create the customization job
|
||||
response = await self._make_request(
|
||||
method="POST",
|
||||
|
@ -326,7 +325,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
"""Optimize a model based on preference data."""
|
||||
raise NotImplementedError("Preference optimization is not implemented yet")
|
||||
|
||||
@webmethod(route="/post-training/job/logs", method="GET")
|
||||
async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
"""Get the container logs of a customization job."""
|
||||
raise NotImplementedError("Job logs are not implemented yet")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue