Merge remote-tracking branch 'origin/main' into if_eval

Botao Chen 2025-03-18 23:23:13 -07:00
commit a690c7b230
123 changed files with 4482 additions and 3161 deletions

View file

@@ -614,118 +614,133 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
input_messages = input_messages + [message]
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message
# Process tool calls in the message
client_tool_calls = []
non_client_tool_calls = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)
# Process non-client tool calls first
for tool_call in non_client_tool_calls:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
# Execute the tool call
async with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_message = ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=tool_execution_step,
)
)
)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to the final message.
output_attachments.append(out_attachment)
# If there are client tool calls, yield a message with only those tool calls
if client_tool_calls:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(timezone.utc).isoformat(),
),
)
yield message
# Create a copy of the message with only client tool calls
client_message = message.model_copy(deep=True)
client_message.tool_calls = client_tool_calls
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
client_message.stop_reason = StopReason.end_of_message
# Yield the message with client tool calls
yield client_message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=tool_call.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to the final message.
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
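The net effect of this split: a turn can now end with a message that carries only client tool calls and stop_reason == end_of_message, after all server-side tool calls have already been executed and their results appended to input_messages. A hedged sketch of what a consumer of this stream might do; drive_turn, resume_turn, and the client_tools mapping are illustrative assumptions, not APIs defined in this diff:

# Assumed imports; exact paths vary by llama-stack version.
from llama_stack.apis.inference import StopReason, ToolResponseMessage

async def drive_turn(stream, client_tools, resume_turn):
    """Consume the agent turn stream sketched above. `stream` yields
    AgentTurnResponseStreamChunk events plus, when client tools are
    pending, the CompletionMessage itself (the `yield client_message`
    above). All names here are illustrative."""
    async for item in stream:
        if not hasattr(item, "tool_calls"):
            continue  # step start / progress / complete chunks
        if item.stop_reason == StopReason.end_of_message and item.tool_calls:
            # By construction, every call left on this message is a client
            # tool: run each locally, then resume the turn with the results.
            responses = [
                ToolResponseMessage(
                    call_id=tc.call_id,
                    content=client_tools[tc.tool_name](tc.arguments),
                )
                for tc in item.tool_calls
            ]
            await resume_turn(responses)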
async def _initialize_tools(
self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
@@ -891,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)
@@ -967,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
return data
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
contents = []
for url in urls:
uri = url.uri
@@ -988,16 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
return ToolResponseMessage(
call_id="",
content=content,
)
if isinstance(message.content, list):
message.content.extend(contents)
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
def _interpret_content_as_attachment(
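attachment_message no longer returns a standalone ToolResponseMessage with an empty call_id; it now mutates the caller's last user message, appending one TextContentItem per downloaded file. The normalization rule at the end of the new function, isolated into a runnable sketch (the dataclass stands in for llama_stack.apis.common.content_types.TextContentItem so the snippet is self-contained):

from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class TextContentItem:  # stand-in for the llama_stack content type
    text: str

def extend_message_content(content: Union[str, Any, List[Any]], new_items: List[TextContentItem]) -> List[Any]:
    # Mirrors the diff: lists are extended, strings are wrapped in a
    # TextContentItem first, and any other single item is listified.
    if isinstance(content, list):
        return content + new_items
    if isinstance(content, str):
        return [TextContentItem(text=content)] + new_items
    return [content] + new_items

# A plain-string user message gains a file-path hint for code_interpreter:
print(extend_message_content("summarize this file", [TextContentItem(text='# User provided a file accessible to you at "/tmp/data.csv"')]))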

View file

@@ -3,20 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"
class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()
@abstractmethod
def load(self):
raise NotImplementedError()
@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
class PandasDataframeDataset:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in the dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
self.df = self._validate_dataset_schema(df)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
self.dataset_infos[dataset.identifier] = dataset
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset: Dataset,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}"
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset.json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
if limit is None or limit == -1:
end = len(dataset_impl)
else:
next_page_token = int(page_token)
end = min(start_index + limit, len(dataset_impl))
start = next_page_token
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_impl[start_index:end]
rows = dataset_info.dataset_impl[start:end]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(dataset_impl) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)
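This file also migrates get_rows_paginated (string page_token) to iterrows (integer start_index/limit, with next_start_index set to None once the final page is returned). A hedged sketch of a consumer paging through a dataset under the new contract; datasetio_api stands for any DatasetIO implementation:

from typing import Any, Dict, List, Optional

async def fetch_all_rows(datasetio_api, dataset_id: str, page_size: int = 100) -> List[Dict[str, Any]]:
    """Page through a dataset with the new iterrows contract: keep asking
    from next_start_index until the response reports None."""
    rows: List[Dict[str, Any]] = []
    start_index: Optional[int] = 0
    while start_index is not None:
        page = await datasetio_api.iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=page_size,
        )
        rows.extend(page.data)
        start_index = page.next_start_index
    return rows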

View file

@@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
@@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)

View file

@@ -10,6 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import logging
import multiprocessing
@@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return ProcessingMessageWrapper(**data).payload
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
def worker_process_entrypoint(
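The diff does not state why parse_message now deep-copies the parsed payload; a plausible reading is that the pydantic wrapper and the returned message should not share nested mutable containers once the message crosses a process or task boundary. The isolation deepcopy buys, in miniature:

import copy

parsed = {"tokens": [1, 2, 3]}
aliased = parsed                 # returning the payload directly shares state
cloned = copy.deepcopy(parsed)   # what parse_message now does

aliased["tokens"].append(4)
print(parsed["tokens"])  # [1, 2, 3, 4]: mutation leaked through the alias
print(cloned["tokens"])  # [1, 2, 3]: the deep copy is unaffected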

View file

@@ -9,6 +9,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
@@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA = {
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

View file

@@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
):
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
@@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str:Any] = {}
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@@ -85,10 +85,10 @@
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
checkpoint_format: str | None = None,
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta":
if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving Hugging Face format checkpoints, we only support saving adapter weights for now

View file

@@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict
from typing import Callable, Dict
import torch
from pydantic import BaseModel
@@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
model_definition: BuildLoraModelCallable
tokenizer_type: BuildTokenizerCallable
checkpoint_type: str
@@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
}
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
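With the Callable aliases hoisted above ModelConfig, the previously Any-typed fields now document what MODEL_CONFIGS entries must hold. A hypothetical entry under that typing; the torchtune builder names are assumptions for illustration, since the diff only changes annotations, not values:

# Assumed torchtune builders; check your torchtune version for exact names.
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b

config = ModelConfig(
    model_definition=lora_llama3_8b,   # Callable[..., torch.nn.Module]
    tokenizer_type=llama3_tokenizer,   # Callable[..., Llama3Tokenizer]
    checkpoint_type="LLAMA3",          # illustrative value
)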

View file

@@ -55,7 +55,7 @@ class SFTDataset(Dataset):
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample)
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())

View file

@@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
@@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
# Currently logging only logs limited training metrics to local disk
# will figure out more logging and how it works with telemetry in future PRs
_checkpointer: TorchtuneCheckpointer
def __init__(
self,
config: TorchtunePostTrainingConfig,
@@ -82,7 +85,7 @@
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
@@ -109,12 +112,12 @@
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir
self.checkpoint_dir = checkpoint_dir
else:
model = resolve_model(self.model_id)
if model is None:
model_obj = resolve_model(self.model_id)
if model_obj is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model)
self.checkpoint_dir = model_checkpoint_dir(model_obj)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format
@@ -135,16 +138,16 @@
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_checkpointing)
if training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(training_config.efficiency_config.enable_activation_offloading)
if training_config.efficiency_config
else False
)
self._enable_activation_checkpointing = False
self._enable_activation_offloading = False
if training_config.efficiency_config:
if training_config.efficiency_config.enable_activation_checkpointing:
self._enable_activation_checkpointing = (
training_config.efficiency_config.enable_activation_checkpointing
)
if training_config.efficiency_config.enable_activation_offloading:
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
@@ -328,13 +331,13 @@
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated(
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
@@ -451,12 +454,12 @@
"""
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss = 0
running_loss: float = 0.0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
memory_stats: Dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -484,7 +487,7 @@
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss
running_loss += current_loss.detach().item()
current_loss.backward()
# Step with optimizer
@@ -500,7 +503,7 @@
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
loss_to_log = running_loss / num_tokens
pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
@@ -523,7 +526,7 @@
)
# Reset running stats for the next step
running_loss = 0
running_loss = 0.0
num_tokens = 0
t0 = time.perf_counter()
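The running_loss change is the apparent motivation for the rest of this hunk: accumulating live loss tensors keeps each batch's autograd graph reachable until the logging-time .item(), whereas detach().item() converts to a plain Python float immediately, so loss_to_log needs no .item() and the reset is a float literal. The difference in miniature:

import torch

x = torch.randn(4, requires_grad=True)
running_loss = 0.0  # plain float, as in this diff

for _ in range(3):
    loss = (x * 2).sum()                   # stand-in for the per-batch loss
    running_loss += loss.detach().item()   # graph-free accumulation for logging
    loss.backward()                        # autograd still works on `loss` itself

print(running_loss / 3)  # already a float; no .item() needed at log time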

View file

@@ -227,13 +227,6 @@ class LlamaGuardShield:
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:]
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
for i, m in enumerate(messages):
print(f"{i}: {m.role}: {m.content}")
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages
async def run(self, messages: List[Message]) -> RunShieldResponse:

View file

@@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@@ -82,12 +84,12 @@ class BasicScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:

View file

@@ -167,11 +167,11 @@ class BraintrustScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()

View file

@@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import logging
import os
import tempfile
@@ -37,7 +38,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
@@ -65,7 +66,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
# Use environment variable to control bwrap usage
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
res = self.code_executor.execute(req)
res = await asyncio.to_thread(self.code_executor.execute, req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
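Wrapping code_executor.execute in asyncio.to_thread moves a blocking, potentially long-running call off the event loop, so other requests keep being served while a script runs. The pattern in isolation (blocking_execute stands in for the synchronous executor):

import asyncio
import time

def blocking_execute(script: str) -> str:
    time.sleep(1)  # stand-in for the synchronous code_executor.execute()
    return f"ran: {script}"

async def main() -> None:
    # Run two blocking executions concurrently; with a direct call, the
    # second would only start after the first finished and the event loop
    # would be stalled for the whole duration.
    results = await asyncio.gather(
        asyncio.to_thread(blocking_execute, "print(1)"),
        asyncio.to_thread(blocking_execute, "print(2)"),
    )
    print(results)

asyncio.run(main())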

View file

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
}
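sample_run_config renders to the string ${env.QDRANT_PATH:~/.llama/<distro_dir>}/qdrant.db, using llama-stack's ${env.NAME:default} convention for run configs. A toy resolver approximating that convention (the real substitution lives in llama-stack's config loading, which this diff does not show):

import os
import re

def resolve_env_template(value: str) -> str:
    # Replace ${env.NAME:default} with the env var NAME, else the default.
    def repl(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}", repl, value)

print(resolve_env_template("${env.QDRANT_PATH:~/.llama/distributions/foo}/qdrant.db"))
# -> ~/.llama/distributions/foo/qdrant.db  (when QDRANT_PATH is unset)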