Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-07 11:08:20 +00:00
fix all iterrows callsites

commit b561cfd902 (parent f117407af6)
6 changed files with 56 additions and 163 deletions
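Most hunks below collapse call sites that previously wrapped across several lines into a single line; the datasetio_api.iterrows() call sites keep the pattern of passing limit=-1 to fetch every row when no example cap is set. Below is a minimal sketch of that callsite shape, assuming a hypothetical DatasetIOLike protocol and load_rows helper (neither name is part of the repo):

# Sketch only: illustrates the iterrows() callsite shape seen in this diff.
# PaginatedRows, DatasetIOLike, and load_rows are assumed names for illustration.
from typing import Any, Dict, List, Optional, Protocol


class PaginatedRows(Protocol):
    data: List[Dict[str, Any]]


class DatasetIOLike(Protocol):
    async def iterrows(self, dataset_id: str, limit: int) -> PaginatedRows: ...


async def load_rows(
    datasetio_api: DatasetIOLike,
    dataset_id: str,
    num_examples: Optional[int] = None,
) -> List[Dict[str, Any]]:
    # Mirrors the callsites below: limit=-1 means "all rows",
    # otherwise the fetch is capped at num_examples.
    all_rows = await datasetio_api.iterrows(
        dataset_id=dataset_id,
        limit=(-1 if num_examples is None else num_examples),
    )
    return all_rows.data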
@@ -229,9 +229,7 @@ def run_evaluation_3():
                     output_res[scoring_fn] = []
                 output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

-            progress_text_container.write(
-                f"Expand to see current processed result ({i + 1} / {len(rows)})"
-            )
+            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
             results_container.json(eval_res, expanded=2)

         progress_bar.progress(1.0, text="Evaluation complete!")
@@ -89,16 +89,10 @@ class MetaReferenceEvalImpl(
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
+            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -124,14 +118,10 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [
-                UserMessage(**x) for x in input_messages if x["role"] == "user"
-            ]
+            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]

             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(
-                agent_id, f"session-{i}"
-            )
+            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
             session_id = session_create_response.session_id

             turn_request = dict(
@@ -140,12 +130,7 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [
-                chunk
-                async for chunk in await self.agents_api.create_agent_turn(
-                    **turn_request
-                )
-            ]
+            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
             final_event = turn_response[-1].event.payload

             # check if there's a memory retrieval step and extract the context
@@ -154,14 +139,10 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(
-                                x.text for x in tool_response.content
-                            )
+                            memory_rag_context = " ".join(x.text for x in tool_response.content)

             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = (
-                final_event.turn.output_message.content
-            )
+            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context

@@ -173,9 +154,7 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert (
-            candidate.sampling_params.max_tokens is not None
-        ), "SamplingParams.max_tokens must be provided"
+        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

         generations = []
         for x in tqdm(input_rows):
@@ -186,39 +165,21 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(
-                    x[ColumnName.chat_completion_input.value]
-                )
-                input_messages = [
-                    UserMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "user"
-                ]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [
-                    SystemMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "system"
-                ]
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             else:
                 raise ValueError("Invalid input row")

@@ -241,8 +202,7 @@ class MetaReferenceEvalImpl(

         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r
-            for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]

         if benchmark_config.scoring_params is not None:
@@ -251,9 +211,7 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {
-                scoring_fn_id: None for scoring_fn_id in scoring_functions
-            }
+            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}

         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
@@ -17,7 +17,8 @@ import torch
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -88,9 +89,7 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError(
-                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
-            )
+            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -99,10 +98,7 @@ class LoraFinetuningSingleDevice:
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))

-            paths = [
-                Path(checkpoint_dir / f"consolidated.{ext}")
-                for ext in ["pth", "00.pth"]
-            ]
+            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"

@@ -117,9 +113,7 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(
-                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
-                )
+                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
             self.checkpoint_dir = model_checkpoint_dir(model)

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@@ -191,9 +185,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

-        self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.training_config.optimizer_config
-        )
+        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
@@ -221,13 +213,8 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = (
-            len(self._training_dataloader) // self._gradient_accumulation_steps
-        )
-        if (
-            self.max_steps_per_epoch is not None
-            and self.max_steps_per_epoch < self._steps_per_epoch
-        ):
+        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
+        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

@@ -241,9 +228,7 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full(
-            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
-        )
+        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)

     def _log_memory_stats(self):
         # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@@ -284,13 +269,9 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)

         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
-            )
+            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})

-        base_missing, base_unexpected = model.load_state_dict(
-            base_model_state_dict, strict=False
-        )
+        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)

         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
@@ -299,9 +280,7 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(
-                lora_weights_state_dict, strict=False
-            )
+            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
@@ -315,14 +294,10 @@ class LoraFinetuningSingleDevice:
         )

         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(
-            self.adapter_params.items(), dtype=self._dtype
-        )
+        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)

         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
-        )
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)

         self._log_memory_stats()

@@ -458,9 +433,7 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack(
-            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-        )
+        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))
@@ -489,9 +462,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
-            )
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0

@@ -499,8 +470,7 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
                 ):
                     break

@@ -508,9 +478,7 @@ class LoraFinetuningSingleDevice:

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
+                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
                 num_tokens += current_num_tokens

                 # Loss is normalized by default so we multiply by the number of tokens
@@ -535,9 +503,7 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens

                    pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    )
+                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")

                     time_per_step = time.perf_counter() - t0
                     log_dict = {
@@ -64,15 +64,11 @@ class BasicScoringImpl(

     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
-            fn_def
-            for impl in self.scoring_fn_id_impls.values()
-            for fn_def in impl.get_supported_scoring_fn_defs()
+            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]

         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "basic"
-            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
+            assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "

         return scoring_fn_defs_list

@@ -86,9 +82,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -118,12 +112,8 @@ class BasicScoringImpl(
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
@@ -122,12 +122,10 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api

         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }

     async def initialize(self) -> None: ...
@@ -137,16 +135,14 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "braintrust"
-            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            assert f.identifier.startswith("braintrust"), (
+                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            )

         return scoring_fn_defs_list

     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError(
-            "Registering scoring function not allowed for braintrust provider"
-        )
+        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")

     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -169,17 +165,13 @@ class BraintrustScoringImpl(
         await self.set_api_key()

         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,
         )
-        res = await self.score(
-            input_rows=all_rows.data, scoring_functions=scoring_functions
-        )
+        res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -220,13 +212,8 @@ class BraintrustScoringImpl(
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

-            score_results = [
-                await self.score_row(input_row, scoring_fn_id)
-                for input_row in input_rows
-            ]
-            aggregation_functions = self.supported_fn_defs_registry[
-                scoring_fn_id
-            ].params.aggregation_functions
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions

             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:
@@ -54,9 +54,9 @@ class LlmAsJudgeScoringImpl(
         scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()

         for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
-            assert f.identifier.startswith(
-                "llm-as-judge"
-            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            assert f.identifier.startswith("llm-as-judge"), (
+                "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            )

         return scoring_fn_defs_list

@@ -70,9 +70,7 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -100,12 +98,8 @@ class LlmAsJudgeScoringImpl(
         for scoring_fn_id in scoring_functions.keys():
             scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,