fix all occurrence

Xi Yan 2025-03-15 14:20:17 -07:00
parent 917679cc2f
commit a9c662d68b
9 changed files with 177 additions and 70 deletions
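In short, the commit renames the DatasetIO method get_rows_paginated to iterrows everywhere it is declared or called; the parameters and the IterrowsResponse return type are unchanged. A minimal sketch of the call from a consumer's point of view, assuming an already-resolved datasetio_api handle and a registered dataset id (both hypothetical; the method name, arguments, and the .rows attribute come from the diffs below):

    # Hedged sketch only -- "datasetio_api" and the dataset id are assumed for illustration.
    response = await datasetio_api.iterrows(
        dataset_id="example_dataset",  # hypothetical id
        rows_in_page=-1,               # -1, as passed by several callers in this commit
    )
    for row in response.rows:          # the response exposes the fetched rows via .rows
        print(row)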


@@ -537,7 +537,7 @@ class DatasetIORouter(DatasetIO):
         logger.debug("DatasetIORouter.shutdown")
         pass

-    async def get_rows_paginated(
+    async def iterrows(
         self,
         dataset_id: str,
         rows_in_page: int,
@@ -545,11 +545,9 @@ class DatasetIORouter(DatasetIO):
         filter_condition: Optional[str] = None,
     ) -> IterrowsResponse:
         logger.debug(
-            f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
+            f"DatasetIORouter.iterrows: {dataset_id}, rows_in_page={rows_in_page}",
         )
-        return await self.routing_table.get_provider_impl(
-            dataset_id
-        ).get_rows_paginated(
+        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
             dataset_id=dataset_id,
             rows_in_page=rows_in_page,
             page_token=page_token,


@@ -166,7 +166,7 @@ def run_evaluation_3():
     eval_candidate = st.session_state["eval_candidate"]

     dataset_id = benchmarks[selected_benchmark].dataset_id
-    rows = llama_stack_api.client.datasetio.get_rows_paginated(
+    rows = llama_stack_api.client.datasetio.iterrows(
        dataset_id=dataset_id,
        rows_in_page=-1,
    )
@@ -230,7 +230,9 @@ def run_evaluation_3():
                 output_res[scoring_fn] = []
             output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

-        progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
+        progress_text_container.write(
+            f"Expand to see current processed result ({i + 1} / {len(rows)})"
+        )
         results_container.json(eval_res, expanded=2)

     progress_bar.progress(1.0, text="Evaluation complete!")
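For context, the renamed synchronous client call used by this UI page can be exercised on its own; a rough sketch, where the dataset id is a stand-in and everything else mirrors the hunk above:

    # Hypothetical standalone use of the client call shown in the hunk above.
    rows = llama_stack_api.client.datasetio.iterrows(
        dataset_id="example_benchmark_dataset",  # assumed id, for illustration only
        rows_in_page=-1,                         # same value the UI code passes
    )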


@@ -128,7 +128,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         await self.kvstore.delete(key=key)
         del self.dataset_infos[dataset_id]

-    async def get_rows_paginated(
+    async def iterrows(
         self,
         dataset_id: str,
         rows_in_page: int,


@@ -89,10 +89,16 @@ class MetaReferenceEvalImpl(
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
-        all_rows = await self.datasetio_api.get_rows_paginated(
+        validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
+        )
+        all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            rows_in_page=(
+                -1
+                if benchmark_config.num_examples is None
+                else benchmark_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -118,10 +124,14 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
+            input_messages = [
+                UserMessage(**x) for x in input_messages if x["role"] == "user"
+            ]

             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
+            session_create_response = await self.agents_api.create_agent_session(
+                agent_id, f"session-{i}"
+            )
             session_id = session_create_response.session_id

             turn_request = dict(
@@ -130,7 +140,12 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
+            turn_response = [
+                chunk
+                async for chunk in await self.agents_api.create_agent_turn(
+                    **turn_request
+                )
+            ]
             final_event = turn_response[-1].event.payload

             # check if there's a memory retrieval step and extract the context
@@ -139,10 +154,14 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(x.text for x in tool_response.content)
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )

             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
+            agent_generation[ColumnName.generated_answer.value] = (
+                final_event.turn.output_message.content
+            )
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
@@ -154,7 +173,9 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"

         generations = []
         for x in tqdm(input_rows):
@@ -165,21 +186,39 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
-                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
+                chat_completion_input_json = json.loads(
+                    x[ColumnName.chat_completion_input.value]
+                )
+                input_messages = [
+                    UserMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "user"
+                ]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
+                messages += [
+                    SystemMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "system"
+                ]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             else:
                 raise ValueError("Invalid input row")
@@ -202,7 +241,8 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]

         if benchmark_config.scoring_params is not None:
@@ -211,7 +251,9 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }

         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict


@@ -17,8 +17,7 @@ import torch
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training
-from torchtune import utils as torchtune_utils
+from torchtune import modules, training, utils as torchtune_utils
 from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -89,7 +88,9 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
+            raise ValueError(
+                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
+            )
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -98,7 +99,10 @@ class LoraFinetuningSingleDevice:

         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
-            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
+            paths = [
+                Path(checkpoint_dir / f"consolidated.{ext}")
+                for ext in ["pth", "00.pth"]
+            ]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"
@@ -113,7 +117,9 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
+                raise ValueError(
+                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
+                )
             self.checkpoint_dir = model_checkpoint_dir(model)

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@@ -185,7 +191,9 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

-        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
+        self._optimizer = await self._setup_optimizer(
+            optimizer_config=self.training_config.optimizer_config
+        )
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
@@ -213,8 +221,13 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
-        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
+        self._steps_per_epoch = (
+            len(self._training_dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
             self._steps_per_epoch = self.max_steps_per_epoch

         self.global_step = self.epochs_run * self._steps_per_epoch
@@ -228,7 +241,9 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
+        self.ignore_labels_cache = torch.full(
+            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )

     def _log_memory_stats(self):
         # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@@ -269,9 +284,13 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)

         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )

-        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
+        base_missing, base_unexpected = model.load_state_dict(
+            base_model_state_dict, strict=False
+        )

         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
@@ -280,7 +299,9 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
+            lora_missing, lora_unexpected = model.load_state_dict(
+                lora_weights_state_dict, strict=False
+            )
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
@@ -294,10 +315,14 @@ class LoraFinetuningSingleDevice:
         )

         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
+        training.validate_expected_param_dtype(
+            self.adapter_params.items(), dtype=self._dtype
+        )

         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )

         self._log_memory_stats()
@@ -328,7 +353,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.get_rows_paginated(
+            return await self.datasetio_api.iterrows(
                 dataset_id=dataset_id,
                 rows_in_page=-1,
             )
@@ -433,7 +458,9 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
+        labels = torch.hstack(
+            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+        )
         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))
@@ -462,7 +489,9 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
+            )
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0
@@ -470,7 +499,8 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
                 ):
                     break
@@ -478,7 +508,9 @@ class LoraFinetuningSingleDevice:

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
                 num_tokens += current_num_tokens

                 # Loss is normalized by default so we multiply by the number of tokens
@@ -503,7 +535,9 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens

                     pbar.update(1)
-                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )

                     time_per_step = time.perf_counter() - t0
                     log_dict = {


@@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (

 from .config import BasicScoringConfig
 from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
-from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
+from .scoring_fn.regex_parser_math_response_scoring_fn import (
+    RegexParserMathResponseScoringFn,
+)
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@@ -62,11 +64,15 @@ class BasicScoringImpl(

     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
-            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
+            fn_def
+            for impl in self.scoring_fn_id_impls.values()
+            for fn_def in impl.get_supported_scoring_fn_defs()
         ]

         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
+            assert f.identifier.startswith(
+                "basic"
+            ), "All basic scoring fn must have identifier prefixed with 'basic'! "

         return scoring_fn_defs_list
@@ -80,9 +86,11 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
+        validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )

-        all_rows = await self.datasetio_api.get_rows_paginated(
+        all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
@@ -110,8 +118,12 @@ class BasicScoringImpl(
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
-            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
+            agg_results = await scoring_fn.aggregate(
+                score_results, scoring_fn_id, scoring_fn_params
+            )
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,


@@ -122,10 +122,12 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api
         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator
+            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def
+            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }

     async def initialize(self) -> None: ...
@@ -135,14 +137,16 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith("braintrust"), (
-                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
-            )
+            assert f.identifier.startswith(
+                "braintrust"
+            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

         return scoring_fn_defs_list

     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
+        raise NotImplementedError(
+            "Registering scoring function not allowed for braintrust provider"
+        )

     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -165,13 +169,17 @@ class BraintrustScoringImpl(
         await self.set_api_key()
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
+        validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )

-        all_rows = await self.datasetio_api.get_rows_paginated(
+        all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
+        res = await self.score(
+            input_rows=all_rows.rows, scoring_functions=scoring_functions
+        )

         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -212,8 +220,13 @@ class BraintrustScoringImpl(

             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
-            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
+            score_results = [
+                await self.score_row(input_row, scoring_fn_id)
+                for input_row in input_rows
+            ]
+            aggregation_functions = self.supported_fn_defs_registry[
+                scoring_fn_id
+            ].params.aggregation_functions

             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:


@@ -54,9 +54,9 @@ class LlmAsJudgeScoringImpl(

         scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
         for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
-            assert f.identifier.startswith("llm-as-judge"), (
-                "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
-            )
+            assert f.identifier.startswith(
+                "llm-as-judge"
+            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "

         return scoring_fn_defs_list
@@ -70,9 +70,11 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
+        validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )

-        all_rows = await self.datasetio_api.get_rows_paginated(
+        all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
@@ -98,8 +100,12 @@ class LlmAsJudgeScoringImpl(
         for scoring_fn_id in scoring_functions.keys():
             scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
-            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
+            agg_results = await scoring_fn.aggregate(
+                score_results, scoring_fn_id, scoring_fn_params
+            )
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,


@@ -73,7 +73,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         await self.kvstore.delete(key=key)
         del self.dataset_infos[dataset_id]

-    async def get_rows_paginated(
+    async def iterrows(
         self,
         dataset_id: str,
         rows_in_page: int,
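Taken together, every provider implementing DatasetIO now exposes the paginated-read method under the new name. A sketch of the renamed signature, assembled from the hunks above (the parameter names and the IterrowsResponse return type appear in the router hunk; the page_token default and the body are assumed):

    async def iterrows(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,        # default assumed; the router passes page_token through
        filter_condition: Optional[str] = None,  # as shown in the router hunk
    ) -> IterrowsResponse:
        ...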