feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented; this commit adds a README.md file with a table of the current CI workflows and what each does.

---------

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent 7f834339ba
commit b381ed6d64

93 changed files with 495 additions and 477 deletions
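The hook and generator script themselves are not shown on this page. As a rough, hypothetical sketch of what a CI-docs generator run by such a pre-commit hook could look like (the file layout, table columns, and the PyYAML dependency are all assumptions, not taken from the commit):

```python
# Hypothetical sketch only -- NOT the script added by this commit.
# Assumes workflows live in .github/workflows and PyYAML is installed.
from pathlib import Path

import yaml


def generate_ci_table(workflows_dir: str = ".github/workflows") -> str:
    """Build a Markdown table listing each workflow file and its display name."""
    rows = []
    for wf in sorted(Path(workflows_dir).glob("*.yml")):
        spec = yaml.safe_load(wf.read_text()) or {}
        name = spec.get("name", wf.stem)  # fall back to the file name if unnamed
        rows.append(f"| {name} | [{wf.name}]({wf.name}) |")
    return "| Name | File |\n| --- | --- |\n" + "\n".join(rows)


if __name__ == "__main__":
    print(generate_ci_table())
```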
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import sys
 import time
 import uuid
```
```diff
@@ -19,10 +18,9 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")
 
 
 skip_because_resource_intensive = pytest.mark.skip(
```
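The change above is the pattern repeated throughout this commit: drop the per-module stdlib logging setup in favor of the project's get_logger helper. A minimal before/after sketch of just that swap:

```python
# Before: stdlib logging configured ad hoc in each test module.
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)

# After: one shared helper, tagged with a subsystem category.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")
log.info("module loaded")  # the returned logger keeps the familiar stdlib call surface
```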
```diff
@@ -71,14 +69,14 @@ class TestPostTraining:
     )
     @pytest.mark.timeout(360)  # 6 minutes timeout
     def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
-        logger.info("Starting supervised fine-tuning test")
+        log.info("Starting supervised fine-tuning test")
 
         # register dataset to train
         dataset = llama_stack_client.datasets.register(
             purpose=purpose,
             source=source,
         )
-        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+        log.info(f"Registered dataset with ID: {dataset.identifier}")
 
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
@@ -105,7 +103,7 @@ class TestPostTraining:
         )
 
         job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        log.info(f"Starting training job with UUID: {job_uuid}")
 
         # train with HF trl SFTTrainer as the default
         _ = llama_stack_client.post_training.supervised_fine_tune(
```
```diff
@@ -121,21 +119,21 @@ class TestPostTraining:
         while True:
             status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
             if not status:
-                logger.error("Job not found")
+                log.error("Job not found")
                 break
 
-            logger.info(f"Current status: {status}")
+            log.info(f"Current status: {status}")
             assert status.status in ["scheduled", "in_progress", "completed"]
             if status.status == "completed":
                 break
 
-            logger.info("Waiting for job to complete...")
+            log.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency
 
         artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
+        log.info(f"Job artifacts: {artifacts}")
 
-        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+        log.info(f"Registered dataset with ID: {dataset.identifier}")
 
     # TODO: Fix these tests to properly represent the Jobs API in training
     #
```
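The loop above is the test's wait-for-completion pattern. Pulled out as a helper it would look roughly like this (the function name and return value are illustrative; the client calls are the ones used in the test):

```python
import time

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")


def wait_for_training_job(llama_stack_client, job_uuid: str, poll_seconds: int = 10):
    """Poll a post-training job until it completes, then return its artifacts."""
    while True:
        status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
        if not status:
            log.error("Job not found")
            return None
        log.info(f"Current status: {status}")
        if status.status == "completed":
            return llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
        log.info("Waiting for job to complete...")
        time.sleep(poll_seconds)  # coarse polling keeps server load low
```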
```diff
@@ -181,17 +179,21 @@ class TestPostTraining:
     )
     @pytest.mark.timeout(360)
     def test_preference_optimize(self, llama_stack_client, purpose, source):
-        logger.info("Starting DPO preference optimization test")
+        log.info("Starting DPO preference optimization test")
 
         # register preference dataset to train
         dataset = llama_stack_client.datasets.register(
             purpose=purpose,
             source=source,
         )
-        logger.info(f"Registered preference dataset with ID: {dataset.identifier}")
+        log.info(f"Registered preference dataset with ID: {dataset.identifier}")
 
         # DPO algorithm configuration
         algorithm_config = DPOAlignmentConfig(
+            reward_scale=1.0,
+            reward_clip=10.0,
+            epsilon=1e-8,
+            gamma=0.99,
             beta=0.1,
             loss_type=DPOLossType.sigmoid,  # Default loss type
         )
```
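With the four fields this hunk adds, the full DPO configuration constructed by the test reads as below (values copied from the diff; the import path and the per-field glosses are assumptions, following the other post-training imports in this file):

```python
from llama_stack.apis.post_training import DPOAlignmentConfig, DPOLossType

algorithm_config = DPOAlignmentConfig(
    reward_scale=1.0,               # presumably scales the reward signal
    reward_clip=10.0,               # presumably clips reward magnitude
    epsilon=1e-8,                   # presumably a numerical-stability term
    gamma=0.99,                     # presumably a discount factor
    beta=0.1,                       # KL-penalty strength, the main DPO knob
    loss_type=DPOLossType.sigmoid,  # Default loss type
)
```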
```diff
@@ -211,7 +213,7 @@ class TestPostTraining:
         )
 
         job_uuid = f"test-dpo-job-{uuid.uuid4()}"
-        logger.info(f"Starting DPO training job with UUID: {job_uuid}")
+        log.info(f"Starting DPO training job with UUID: {job_uuid}")
 
         # train with HuggingFace DPO implementation
         _ = llama_stack_client.post_training.preference_optimize(
@@ -226,15 +228,15 @@ class TestPostTraining:
         while True:
             status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
             if not status:
-                logger.error("DPO job not found")
+                log.error("DPO job not found")
                 break
 
-            logger.info(f"Current DPO status: {status}")
+            log.info(f"Current DPO status: {status}")
             if status.status == "completed":
                 break
 
-            logger.info("Waiting for DPO job to complete...")
+            log.info("Waiting for DPO job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency
 
         artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"DPO job artifacts: {artifacts}")
+        log.info(f"DPO job artifacts: {artifacts}")
```
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import time
 from io import BytesIO
 
@@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient
 from openai import BadRequestError as OpenAIBadRequestError
 
+from llama_stack.apis.vector_io import Chunk
 from llama_stack.core.library_client import LlamaStackAsLibraryClient
+from llama_stack.log import get_logger
 
-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector-io")
 
 
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
```
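Note that the category differs by subsystem: the post-training tests above use "core", while this vector-store test uses "vector-io". A short sketch of the convention (the module names here are illustrative):

```python
from llama_stack.log import get_logger

# Each module tags its logger with the subsystem it exercises,
# which lets log output be filtered per category.
core_log = get_logger(name="tests.post_training", category="core")
vio_log = get_logger(name="tests.vector_io", category="vector-io")

vio_log.warning("Failed to clear vector stores")  # same methods as a stdlib logger
```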
```diff
@@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client):
                 compat_client.vector_stores.delete(vector_store_id=store.id)
         except Exception:
             # If the API is not available or fails, just continue
-            logger.warning("Failed to clear vector stores")
+            log.warning("Failed to clear vector stores")
             pass
 
     def clear_files():
@@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client):
                 compat_client.files.delete(file_id=file.id)
         except Exception:
             # If the API is not available or fails, just continue
-            logger.warning("Failed to clear files")
+            log.warning("Failed to clear files")
             pass
 
     clear_vector_stores()
```
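Both cleanup helpers in this fixture follow the same swallow-and-log shape. A condensed sketch of the pattern (the fixture wiring and the list() call are assumptions; only the delete calls and the except branch appear in the diff):

```python
import pytest

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="vector-io")


@pytest.fixture
def compat_client_with_empty_stores(compat_client):
    def clear_vector_stores():
        try:
            for store in compat_client.vector_stores.list().data:
                compat_client.vector_stores.delete(vector_store_id=store.id)
        except Exception:
            # If the API is not available or fails, just continue
            log.warning("Failed to clear vector stores")

    clear_vector_stores()   # start each test from a clean slate
    yield compat_client
    clear_vector_stores()   # and tidy up afterwards
```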