feat: Updating files/content response to return additional fields

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-08-06 16:55:14 -04:00
parent e12524af85
commit a19c16428f
143 changed files with 6907 additions and 15104 deletions

View file

@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import Any
import pandas
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
@ -44,6 +42,8 @@ class PandasDataframeDataset:
if self.dataset_def.source.type == "uri":
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
import pandas
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
@ -103,6 +103,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
import pandas
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()

View file

@ -71,8 +71,13 @@ class HuggingFacePostTrainingConfig(BaseModel):
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str = "./checkpoints/dpo"
dpo_output_dir: str
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
return {
"checkpoint_format": "huggingface",
"distributed_backend": None,
"device": "cpu",
"dpo_output_dir": __distro_dir__ + "/dpo_output",
}

View file

@ -22,15 +22,8 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
@ -85,6 +78,10 @@ class HuggingFacePostTrainingImpl:
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
@ -124,6 +121,10 @@ class HuggingFacePostTrainingImpl:
logger_config: dict[str, Any],
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
on_log_message_cb("Starting HF DPO alignment")
recipe = HFDPOAlignmentSingleDevice(
@ -168,7 +169,6 @@ class HuggingFacePostTrainingImpl:
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
@ -195,16 +195,13 @@ class HuggingFacePostTrainingImpl:
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]

View file

@ -23,12 +23,8 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
@ -84,6 +80,10 @@ class TorchtunePostTrainingImpl:
if isinstance(algorithm_config, LoraFinetuningConfig):
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
@ -144,7 +144,6 @@ class TorchtunePostTrainingImpl:
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
@ -171,11 +170,9 @@ class TorchtunePostTrainingImpl:
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
import uuid
from string import Template
from typing import Any
@ -20,6 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -67,6 +70,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
@ -194,6 +222,34 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
if isinstance(input, list):
messages = input.copy()
else:
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run_moderation(messages)
class LlamaGuardShield:
def __init__(
@ -340,3 +396,117 @@ class LlamaGuardShield:
)
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
stream=False,
)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
"""Create a ModerationObject for either safe or unsafe content.
Args:
model: The model name
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
Returns:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
flagged = False
user_message = None
metadata = {}
# Handle unsafe case
if unsafe_code:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
}
flagged = True
user_message = CANNED_RESPONSE_TEXT
metadata = {"violation_type": unsafe_code_list}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
"""Check if content is safe based on response and unsafe code."""
if response.strip() == SAFE_RESPONSE:
return True
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return True
return False
def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip()
if self.is_content_safe(response):
return self.create_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response)
if not unsafe_code:
raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code):
return self.create_moderation_object(self.model)
else:
return self.create_moderation_object(self.model, unsafe_code)

View file

@ -28,9 +28,6 @@ class ConsoleSpanProcessor(SpanProcessor):
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
@ -67,7 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor):
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
logger.info(f"/r[dim]{key}[/dim]: {value}")
logger.info(f"[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""

View file

@ -4,10 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import threading
from typing import Any
from opentelemetry import metrics, trace
logger = logging.getLogger(__name__)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
@ -110,7 +113,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
@ -126,9 +129,11 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
@ -188,6 +193,38 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Always log to console if console sink is enabled (debug)
if TelemetrySink.CONSOLE in self.config.sinks:
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
# Add metric as an event to the current span
try:
with self._lock:
# Only try to add to span if we have a valid span_id
if event.span_id:
try:
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=f"metric.{event.metric}",
attributes={
"value": event.value,
"unit": event.unit,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
except (ValueError, KeyError):
# Invalid span_id or span not found, but we already logged to console above
pass
except Exception:
# Lock acquisition failed
logger.debug("Failed to acquire lock to add metric to span")
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):